From a2f2f48e717d5a17e4b2d9592801eed058afe18e Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sat, 25 Jan 2025 09:22:07 +0800 Subject: [PATCH 001/136] empty file --- python/sglang/srt/managers/generation_manager.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 python/sglang/srt/managers/generation_manager.py diff --git a/python/sglang/srt/managers/generation_manager.py b/python/sglang/srt/managers/generation_manager.py new file mode 100644 index 00000000000..e69de29bb2d From 815dbc30ef44381fe858f41c92e938f31f0b7c3e Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sat, 25 Jan 2025 09:23:03 +0800 Subject: [PATCH 002/136] empty class --- python/sglang/srt/managers/generation_manager.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/python/sglang/srt/managers/generation_manager.py b/python/sglang/srt/managers/generation_manager.py index e69de29bb2d..c5e1c014e0f 100644 --- a/python/sglang/srt/managers/generation_manager.py +++ b/python/sglang/srt/managers/generation_manager.py @@ -0,0 +1,10 @@ +class GenerationManager: + pass + + +class GenerationConverter: + pass + + +class _MetricManager: + pass From 3c0e52f7214ce7ce4bb4284006df9c2ad88ac41a Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sat, 25 Jan 2025 09:25:04 +0800 Subject: [PATCH 003/136] mv MetricManager --- .../sglang/srt/managers/generation_manager.py | 65 ++++++++++++++++++- .../sglang/srt/managers/tokenizer_manager.py | 46 +------------ 2 files changed, 66 insertions(+), 45 deletions(-) diff --git a/python/sglang/srt/managers/generation_manager.py b/python/sglang/srt/managers/generation_manager.py index c5e1c014e0f..083e63c6e74 100644 --- a/python/sglang/srt/managers/generation_manager.py +++ b/python/sglang/srt/managers/generation_manager.py @@ -1,3 +1,11 @@ +import dataclasses +import time +from typing import Optional + +from sglang.srt.metrics.collector import TokenizerMetricsCollector +from sglang.srt.server_args import ServerArgs + + class GenerationManager: pass @@ -7,4 +15,59 @@ class GenerationConverter: class _MetricManager: - pass + def __init__(self, server_args: ServerArgs): + self.metrics_collector = TokenizerMetricsCollector( + labels={ + "model_name": server_args.served_model_name, + # TODO: Add lora name/path in the future, + }, + ) + + def handle_batch_output_metrics( + self, + recv_obj, + i: int, + state: "_MetricReqState", + finished: bool, + stream: Optional[bool], + ): + completion_tokens = ( + recv_obj.completion_tokens[i] + if getattr(recv_obj, "completion_tokens", None) + else 0 + ) + + if state.first_token_time is None: + state.first_token_time = time.time() + self.metrics_collector.observe_time_to_first_token( + state.first_token_time - state.created_time + ) + else: + if completion_tokens >= 2: + # Compute time_per_output_token for the streaming case + self.metrics_collector.observe_time_per_output_token( + (time.time() - state.first_token_time) / (completion_tokens - 1) + ) + + if state.finished: + self.metrics_collector.observe_one_finished_request( + recv_obj.prompt_tokens[i], completion_tokens + ) + self.metrics_collector.observe_e2e_request_latency( + time.time() - state.created_time + ) + # Compute time_per_output_token for the non-streaming case + if ( + hasattr(state.obj, "stream") + and not state.obj.stream + and completion_tokens >= 1 + ): + self.metrics_collector.observe_time_per_output_token( + (time.time() - state.created_time) / completion_tokens + ) + + +@dataclasses.dataclass +class _MetricReqState: + created_time: float + first_token_time: Optional[float] = 
None diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 2be2e532d07..b190025a9e5 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -96,10 +96,6 @@ class ReqState: event: asyncio.Event obj: Any - # For metrics - created_time: float - first_token_time: Optional[float] = None - # For streaming output last_output_offset: int = 0 @@ -217,12 +213,7 @@ def __init__( # Metrics if self.enable_metrics: - self.metrics_collector = TokenizerMetricsCollector( - labels={ - "model_name": self.server_args.served_model_name, - # TODO: Add lora name/path in the future, - }, - ) + TODO_moved self._result_dispatcher = TypeBasedDispatcher( [ @@ -886,40 +877,7 @@ def detokenize_top_logprobs_tokens( return ret def collect_metrics(self, state: ReqState, recv_obj: BatchStrOut, i: int): - completion_tokens = ( - recv_obj.completion_tokens[i] - if getattr(recv_obj, "completion_tokens", None) - else 0 - ) - - if state.first_token_time is None: - state.first_token_time = time.time() - self.metrics_collector.observe_time_to_first_token( - state.first_token_time - state.created_time - ) - else: - if completion_tokens >= 2: - # Compute time_per_output_token for the streaming case - self.metrics_collector.observe_time_per_output_token( - (time.time() - state.first_token_time) / (completion_tokens - 1) - ) - - if state.finished: - self.metrics_collector.observe_one_finished_request( - recv_obj.prompt_tokens[i], completion_tokens - ) - self.metrics_collector.observe_e2e_request_latency( - time.time() - state.created_time - ) - # Compute time_per_output_token for the non-streaming case - if ( - hasattr(state.obj, "stream") - and not state.obj.stream - and completion_tokens >= 1 - ): - self.metrics_collector.observe_time_per_output_token( - (time.time() - state.created_time) / completion_tokens - ) + TODO_moved def dump_requests(self, state: ReqState, out_dict: dict): self.dump_request_list.append( From 65b3a374710f64b9e82d4830cf30ab3d11bd29bc Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sat, 25 Jan 2025 09:25:40 +0800 Subject: [PATCH 004/136] fix --- python/sglang/srt/managers/generation_manager.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/python/sglang/srt/managers/generation_manager.py b/python/sglang/srt/managers/generation_manager.py index 083e63c6e74..0fad03deaeb 100644 --- a/python/sglang/srt/managers/generation_manager.py +++ b/python/sglang/srt/managers/generation_manager.py @@ -49,7 +49,7 @@ def handle_batch_output_metrics( (time.time() - state.first_token_time) / (completion_tokens - 1) ) - if state.finished: + if finished: self.metrics_collector.observe_one_finished_request( recv_obj.prompt_tokens[i], completion_tokens ) @@ -57,11 +57,7 @@ def handle_batch_output_metrics( time.time() - state.created_time ) # Compute time_per_output_token for the non-streaming case - if ( - hasattr(state.obj, "stream") - and not state.obj.stream - and completion_tokens >= 1 - ): + if stream is not None and not stream and completion_tokens >= 1: self.metrics_collector.observe_time_per_output_token( (time.time() - state.created_time) / completion_tokens ) From 6ce52366741a7804ba201775546475a8210a887e Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sat, 25 Jan 2025 09:26:19 +0800 Subject: [PATCH 005/136] mv _ReqState --- python/sglang/srt/managers/generation_manager.py | 16 +++++++++++++++- python/sglang/srt/managers/tokenizer_manager.py | 13 ------------- 2 files changed, 15 
insertions(+), 14 deletions(-) diff --git a/python/sglang/srt/managers/generation_manager.py b/python/sglang/srt/managers/generation_manager.py index 0fad03deaeb..2e0715ead63 100644 --- a/python/sglang/srt/managers/generation_manager.py +++ b/python/sglang/srt/managers/generation_manager.py @@ -1,6 +1,7 @@ +import asyncio import dataclasses import time -from typing import Optional +from typing import Optional, List, Any from sglang.srt.metrics.collector import TokenizerMetricsCollector from sglang.srt.server_args import ServerArgs @@ -63,6 +64,19 @@ def handle_batch_output_metrics( ) +@dataclasses.dataclass +class _ReqState: + """Store the state a request.""" + + out_list: List + finished: bool + event: asyncio.Event + obj: Any + + # For streaming output + last_output_offset: int = 0 + + @dataclasses.dataclass class _MetricReqState: created_time: float diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index b190025a9e5..17b37f0f4b3 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -87,19 +87,6 @@ logger = logging.getLogger(__name__) -@dataclasses.dataclass -class ReqState: - """Store the state a request.""" - - out_list: List - finished: bool - event: asyncio.Event - obj: Any - - # For streaming output - last_output_offset: int = 0 - - class TokenizerManager: """TokenizerManager is a process that tokenizes the text.""" From 7ca0a47d56f770433b457aa5e563fa3418648ea0 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sat, 25 Jan 2025 09:27:16 +0800 Subject: [PATCH 006/136] mv GenerationConverter.init --- .../sglang/srt/managers/generation_manager.py | 41 ++++++++++++++++++- .../sglang/srt/managers/tokenizer_manager.py | 29 ------------- 2 files changed, 40 insertions(+), 30 deletions(-) diff --git a/python/sglang/srt/managers/generation_manager.py b/python/sglang/srt/managers/generation_manager.py index 2e0715ead63..b42e913148f 100644 --- a/python/sglang/srt/managers/generation_manager.py +++ b/python/sglang/srt/managers/generation_manager.py @@ -1,8 +1,11 @@ import asyncio import dataclasses +import os import time from typing import Optional, List, Any +from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer +from sglang.srt.managers.image_processor import get_dummy_image_processor, get_image_processor from sglang.srt.metrics.collector import TokenizerMetricsCollector from sglang.srt.server_args import ServerArgs @@ -12,7 +15,43 @@ class GenerationManager: class GenerationConverter: - pass + """Preprocessors and postprocessors for generation""" + + def __init__( + self, + server_args: ServerArgs, + ): + self.server_args = server_args + self.model_config = _compute_model_config(server_args) + + # Create image processor placeholder + self.image_processor = get_dummy_image_processor() + + # Create tokenizer + if server_args.skip_tokenizer_init: + self.tokenizer = self.processor = None + else: + if self.model_config.is_multimodal: + self.processor = get_processor( + server_args.tokenizer_path, + tokenizer_mode=server_args.tokenizer_mode, + trust_remote_code=server_args.trust_remote_code, + revision=server_args.revision, + ) + self.tokenizer = self.processor.tokenizer + os.environ["TOKENIZERS_PARALLELISM"] = "false" + + # We want to parallelize the image pre-processing so we create an executor for it + self.image_processor = get_image_processor( + self.model_config.hf_config, server_args, self.processor + ) + else: + self.tokenizer = get_tokenizer( + 
server_args.tokenizer_path, + tokenizer_mode=server_args.tokenizer_mode, + trust_remote_code=server_args.trust_remote_code, + revision=server_args.revision, + ) class _MetricManager: diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 17b37f0f4b3..efe119f7f99 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -129,35 +129,6 @@ def __init__( self.context_len = self.model_config.context_len self.image_token_id = self.model_config.image_token_id - # Create image processor placeholder - self.image_processor = get_dummy_image_processor() - - # Create tokenizer - if server_args.skip_tokenizer_init: - self.tokenizer = self.processor = None - else: - if self.model_config.is_multimodal: - self.processor = get_processor( - server_args.tokenizer_path, - tokenizer_mode=server_args.tokenizer_mode, - trust_remote_code=server_args.trust_remote_code, - revision=server_args.revision, - ) - self.tokenizer = self.processor.tokenizer - os.environ["TOKENIZERS_PARALLELISM"] = "false" - - # We want to parallelize the image pre-processing so we create an executor for it - self.image_processor = get_image_processor( - self.model_config.hf_config, server_args, self.processor - ) - else: - self.tokenizer = get_tokenizer( - server_args.tokenizer_path, - tokenizer_mode=server_args.tokenizer_mode, - trust_remote_code=server_args.trust_remote_code, - revision=server_args.revision, - ) - # Store states self.no_create_loop = False self.rid_to_state: Dict[str, ReqState] = {} From b88e450af046166570eb11e6d69a84a72a6fef90 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sat, 25 Jan 2025 09:28:02 +0800 Subject: [PATCH 007/136] mv tokenize_request --- .../sglang/srt/managers/generation_manager.py | 101 +++++++++++++++++- .../sglang/srt/managers/tokenizer_manager.py | 96 ----------------- 2 files changed, 100 insertions(+), 97 deletions(-) diff --git a/python/sglang/srt/managers/generation_manager.py b/python/sglang/srt/managers/generation_manager.py index b42e913148f..5697e868e6c 100644 --- a/python/sglang/srt/managers/generation_manager.py +++ b/python/sglang/srt/managers/generation_manager.py @@ -2,11 +2,14 @@ import dataclasses import os import time -from typing import Optional, List, Any +from typing import Optional, List, Any, Union, Dict from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer from sglang.srt.managers.image_processor import get_dummy_image_processor, get_image_processor +from sglang.srt.managers.io_struct import GenerateReqInput, EmbeddingReqInput, SessionParams, TokenizedGenerateReqInput, \ + TokenizedEmbeddingReqInput from sglang.srt.metrics.collector import TokenizerMetricsCollector +from sglang.srt.sampling.sampling_params import SamplingParams from sglang.srt.server_args import ServerArgs @@ -53,6 +56,102 @@ def __init__( revision=server_args.revision, ) + async def tokenize_request( + self, + obj: Union[GenerateReqInput, EmbeddingReqInput], + ): + """Tokenize one request.""" + # Tokenize + input_embeds = None + input_text = obj.text + if obj.input_embeds is not None: + if not self.server_args.disable_radix_cache: + raise ValueError( + "input_embeds is provided while disable_radix_cache is False. " + "Please add `--disable-radix-cache` when you launch the server " + "if you want to use input_embeds as inputs." 
+ ) + input_embeds = obj.input_embeds + input_ids = obj.input_ids + elif obj.input_ids is not None: + input_ids = obj.input_ids + else: + if self.tokenizer is None: + raise ValueError( + "The engine initialized with skip_tokenizer_init=True cannot " + "accept text prompts. Please provide input_ids or re-initialize " + "the engine with skip_tokenizer_init=False." + ) + input_ids = self.tokenizer.encode(input_text) + + if self.is_generation: + # TODO: also support getting embeddings for multimodal models + image_inputs: Dict = await self.image_processor.process_images_async( + obj.image_data, input_text or input_ids, obj, self.max_req_input_len + ) + if image_inputs and "input_ids" in image_inputs: + input_ids = image_inputs["input_ids"] + return_logprob = obj.return_logprob + logprob_start_len = obj.logprob_start_len + top_logprobs_num = obj.top_logprobs_num + session_params = ( + SessionParams(**obj.session_params) if obj.session_params else None + ) + + input_token_num = len(input_ids) if input_ids is not None else 0 + if input_token_num >= self.context_len: + raise ValueError( + f"The input ({input_token_num} tokens) is longer than the " + f"model's context length ({self.context_len} tokens)." + ) + + if ( + obj.sampling_params.get("max_new_tokens") is not None + and obj.sampling_params.get("max_new_tokens") + input_token_num + >= self.context_len + ): + raise ValueError( + f"Requested token count exceeds the model's maximum context length " + f"of {self.context_len} tokens. You requested a total of " + f"{obj.sampling_params.get('max_new_tokens') + input_token_num} " + f"tokens: {input_token_num} tokens from the input messages and " + f"{obj.sampling_params.get('max_new_tokens')} tokens for the " + f"completion. Please reduce the number of tokens in the input " + f"messages or the completion to fit within the limit." + ) + + # Parse sampling parameters + sampling_params = SamplingParams(**obj.sampling_params) + sampling_params.normalize(self.tokenizer) + sampling_params.verify() + + # Build return object + if isinstance(obj, GenerateReqInput): + tokenized_obj = TokenizedGenerateReqInput( + obj.rid, + input_text, + input_ids, + image_inputs, + sampling_params, + return_logprob, + logprob_start_len, + top_logprobs_num, + obj.stream, + lora_path=obj.lora_path, + input_embeds=input_embeds, + session_params=session_params, + custom_logit_processor=obj.custom_logit_processor, + ) + elif isinstance(obj, EmbeddingReqInput): + tokenized_obj = TokenizedEmbeddingReqInput( + obj.rid, + input_text, + input_ids, + sampling_params, + ) + + return tokenized_obj + class _MetricManager: def __init__(self, server_args: ServerArgs): diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index efe119f7f99..5d7a54fc80d 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -247,102 +247,6 @@ async def generate_request( ): yield response - async def _tokenize_one_request( - self, - obj: Union[GenerateReqInput, EmbeddingReqInput], - ): - """Tokenize one request.""" - # Tokenize - input_embeds = None - input_text = obj.text - if obj.input_embeds is not None: - if not self.server_args.disable_radix_cache: - raise ValueError( - "input_embeds is provided while disable_radix_cache is False. " - "Please add `--disable-radix-cache` when you launch the server " - "if you want to use input_embeds as inputs." 
- ) - input_embeds = obj.input_embeds - input_ids = obj.input_ids - elif obj.input_ids is not None: - input_ids = obj.input_ids - else: - if self.tokenizer is None: - raise ValueError( - "The engine initialized with skip_tokenizer_init=True cannot " - "accept text prompts. Please provide input_ids or re-initialize " - "the engine with skip_tokenizer_init=False." - ) - input_ids = self.tokenizer.encode(input_text) - - if self.is_generation: - # TODO: also support getting embeddings for multimodal models - image_inputs: Dict = await self.image_processor.process_images_async( - obj.image_data, input_text or input_ids, obj, self.max_req_input_len - ) - if image_inputs and "input_ids" in image_inputs: - input_ids = image_inputs["input_ids"] - return_logprob = obj.return_logprob - logprob_start_len = obj.logprob_start_len - top_logprobs_num = obj.top_logprobs_num - session_params = ( - SessionParams(**obj.session_params) if obj.session_params else None - ) - - input_token_num = len(input_ids) if input_ids is not None else 0 - if input_token_num >= self.context_len: - raise ValueError( - f"The input ({input_token_num} tokens) is longer than the " - f"model's context length ({self.context_len} tokens)." - ) - - if ( - obj.sampling_params.get("max_new_tokens") is not None - and obj.sampling_params.get("max_new_tokens") + input_token_num - >= self.context_len - ): - raise ValueError( - f"Requested token count exceeds the model's maximum context length " - f"of {self.context_len} tokens. You requested a total of " - f"{obj.sampling_params.get('max_new_tokens') + input_token_num} " - f"tokens: {input_token_num} tokens from the input messages and " - f"{obj.sampling_params.get('max_new_tokens')} tokens for the " - f"completion. Please reduce the number of tokens in the input " - f"messages or the completion to fit within the limit." 
- ) - - # Parse sampling parameters - sampling_params = SamplingParams(**obj.sampling_params) - sampling_params.normalize(self.tokenizer) - sampling_params.verify() - - # Build return object - if isinstance(obj, GenerateReqInput): - tokenized_obj = TokenizedGenerateReqInput( - obj.rid, - input_text, - input_ids, - image_inputs, - sampling_params, - return_logprob, - logprob_start_len, - top_logprobs_num, - obj.stream, - lora_path=obj.lora_path, - input_embeds=input_embeds, - session_params=session_params, - custom_logit_processor=obj.custom_logit_processor, - ) - elif isinstance(obj, EmbeddingReqInput): - tokenized_obj = TokenizedEmbeddingReqInput( - obj.rid, - input_text, - input_ids, - sampling_params, - ) - - return tokenized_obj - def _send_one_request( self, obj: Union[GenerateReqInput, EmbeddingReqInput], From 3b8ed7bcbb4fa3f58e2676fb46c8b025503dcfbf Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sat, 25 Jan 2025 09:28:21 +0800 Subject: [PATCH 008/136] simp branch --- python/sglang/srt/managers/generation_manager.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/python/sglang/srt/managers/generation_manager.py b/python/sglang/srt/managers/generation_manager.py index 5697e868e6c..43a22ce21e8 100644 --- a/python/sglang/srt/managers/generation_manager.py +++ b/python/sglang/srt/managers/generation_manager.py @@ -127,7 +127,7 @@ async def tokenize_request( # Build return object if isinstance(obj, GenerateReqInput): - tokenized_obj = TokenizedGenerateReqInput( + return TokenizedGenerateReqInput( obj.rid, input_text, input_ids, @@ -143,14 +143,14 @@ async def tokenize_request( custom_logit_processor=obj.custom_logit_processor, ) elif isinstance(obj, EmbeddingReqInput): - tokenized_obj = TokenizedEmbeddingReqInput( + return TokenizedEmbeddingReqInput( obj.rid, input_text, input_ids, sampling_params, ) - - return tokenized_obj + else: + raise NotImplementedError class _MetricManager: From 2f47f922cd5167b302deaaed0887c4728bbded69 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sat, 25 Jan 2025 09:29:00 +0800 Subject: [PATCH 009/136] tokenize_requests --- python/sglang/srt/managers/generation_manager.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/python/sglang/srt/managers/generation_manager.py b/python/sglang/srt/managers/generation_manager.py index 43a22ce21e8..888a8da088f 100644 --- a/python/sglang/srt/managers/generation_manager.py +++ b/python/sglang/srt/managers/generation_manager.py @@ -152,6 +152,16 @@ async def tokenize_request( else: raise NotImplementedError + def tokenize_requests( + self, + obj: Union[GenerateReqInput, EmbeddingReqInput], + ) -> List[Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput]]: + objs = [obj[i] for i in range(obj.batch_size)] + loop = asyncio.get_event_loop() + return loop.run_until_complete( + asyncio.gather(*(self.tokenize_request(obj) for obj in objs)) + ) + class _MetricManager: def __init__(self, server_args: ServerArgs): From e21a05ee478090e7898127c70bf56c23af3edd0e Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sat, 25 Jan 2025 09:30:14 +0800 Subject: [PATCH 010/136] mv postprocess_response --- .../sglang/srt/managers/generation_manager.py | 48 ++++++++++++++++++- .../sglang/srt/managers/tokenizer_manager.py | 39 +-------------- 2 files changed, 48 insertions(+), 39 deletions(-) diff --git a/python/sglang/srt/managers/generation_manager.py b/python/sglang/srt/managers/generation_manager.py index 888a8da088f..da1bde200aa 100644 --- a/python/sglang/srt/managers/generation_manager.py +++ 
b/python/sglang/srt/managers/generation_manager.py @@ -7,7 +7,7 @@ from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer from sglang.srt.managers.image_processor import get_dummy_image_processor, get_image_processor from sglang.srt.managers.io_struct import GenerateReqInput, EmbeddingReqInput, SessionParams, TokenizedGenerateReqInput, \ - TokenizedEmbeddingReqInput + TokenizedEmbeddingReqInput, BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut from sglang.srt.metrics.collector import TokenizerMetricsCollector from sglang.srt.sampling.sampling_params import SamplingParams from sglang.srt.server_args import ServerArgs @@ -162,6 +162,52 @@ def tokenize_requests( asyncio.gather(*(self.tokenize_request(obj) for obj in objs)) ) + def postprocess_response( + self, + recv_obj: Union[BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut], + index: int, + req_obj: Union[GenerateReqInput, EmbeddingReqInput], + ) -> Dict[str, Any]: + meta_info = { + "id": rid, + "finish_reason": recv_obj.finished_reasons[i], + "prompt_tokens": recv_obj.prompt_tokens[i], + } + + if getattr(state.obj, "return_logprob", False): + self.convert_logprob_style( + meta_info, + state.obj.top_logprobs_num, + state.obj.return_text_in_logprobs, + recv_obj, + i, + ) + + if not isinstance(recv_obj, BatchEmbeddingOut): + meta_info.update( + { + "completion_tokens": recv_obj.completion_tokens[i], + "cached_tokens": recv_obj.cached_tokens[i], + } + ) + + if isinstance(recv_obj, BatchStrOut): + out_dict = { + "text": recv_obj.output_strs[i], + "meta_info": meta_info, + } + elif isinstance(recv_obj, BatchTokenIDOut): + out_dict = { + "token_ids": recv_obj.output_ids[i], + "meta_info": meta_info, + } + else: + assert isinstance(recv_obj, BatchEmbeddingOut) + out_dict = { + "embedding": recv_obj.embeddings[i], + "meta_info": meta_info, + } + class _MetricManager: def __init__(self, server_args: ServerArgs): diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 5d7a54fc80d..ff495ce42c3 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -623,45 +623,8 @@ def _handle_batch_output( if state is None: continue - meta_info = { - "id": rid, - "finish_reason": recv_obj.finished_reasons[i], - "prompt_tokens": recv_obj.prompt_tokens[i], - } - - if getattr(state.obj, "return_logprob", False): - self.convert_logprob_style( - meta_info, - state.obj.top_logprobs_num, - state.obj.return_text_in_logprobs, - recv_obj, - i, - ) - - if not isinstance(recv_obj, BatchEmbeddingOut): - meta_info.update( - { - "completion_tokens": recv_obj.completion_tokens[i], - "cached_tokens": recv_obj.cached_tokens[i], - } - ) + TODO_moved_postprocesse_response - if isinstance(recv_obj, BatchStrOut): - out_dict = { - "text": recv_obj.output_strs[i], - "meta_info": meta_info, - } - elif isinstance(recv_obj, BatchTokenIDOut): - out_dict = { - "token_ids": recv_obj.output_ids[i], - "meta_info": meta_info, - } - else: - assert isinstance(recv_obj, BatchEmbeddingOut) - out_dict = { - "embedding": recv_obj.embeddings[i], - "meta_info": meta_info, - } state.out_list.append(out_dict) state.finished = recv_obj.finished_reasons[i] is not None state.event.set() From ab5d79a6a850b117304c767c8789b33fdd580e4a Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sat, 25 Jan 2025 09:30:36 +0800 Subject: [PATCH 011/136] simp code --- python/sglang/srt/managers/generation_manager.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git 
a/python/sglang/srt/managers/generation_manager.py b/python/sglang/srt/managers/generation_manager.py index da1bde200aa..83cc8737ec2 100644 --- a/python/sglang/srt/managers/generation_manager.py +++ b/python/sglang/srt/managers/generation_manager.py @@ -192,21 +192,22 @@ def postprocess_response( ) if isinstance(recv_obj, BatchStrOut): - out_dict = { + return { "text": recv_obj.output_strs[i], "meta_info": meta_info, } elif isinstance(recv_obj, BatchTokenIDOut): - out_dict = { + return { "token_ids": recv_obj.output_ids[i], "meta_info": meta_info, } - else: - assert isinstance(recv_obj, BatchEmbeddingOut) - out_dict = { + elif isinstance(recv_obj, BatchEmbeddingOut): + return { "embedding": recv_obj.embeddings[i], "meta_info": meta_info, } + else: + raise NotImplementedError class _MetricManager: From 053c8f47999eb7792112028e15d3da645b4dd23e Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sat, 25 Jan 2025 09:31:14 +0800 Subject: [PATCH 012/136] extract _compute_meta_info --- .../sglang/srt/managers/generation_manager.py | 39 +++++++++++-------- 1 file changed, 22 insertions(+), 17 deletions(-) diff --git a/python/sglang/srt/managers/generation_manager.py b/python/sglang/srt/managers/generation_manager.py index 83cc8737ec2..1dd6629b776 100644 --- a/python/sglang/srt/managers/generation_manager.py +++ b/python/sglang/srt/managers/generation_manager.py @@ -168,6 +168,27 @@ def postprocess_response( index: int, req_obj: Union[GenerateReqInput, EmbeddingReqInput], ) -> Dict[str, Any]: + meta_info = self._compute_meta_info(index, recv_obj, req_obj) + + if isinstance(recv_obj, BatchStrOut): + return { + "text": recv_obj.output_strs[i], + "meta_info": meta_info, + } + elif isinstance(recv_obj, BatchTokenIDOut): + return { + "token_ids": recv_obj.output_ids[i], + "meta_info": meta_info, + } + elif isinstance(recv_obj, BatchEmbeddingOut): + return { + "embedding": recv_obj.embeddings[i], + "meta_info": meta_info, + } + else: + raise NotImplementedError + + def _compute_meta_info(self, index, recv_obj, req_obj): meta_info = { "id": rid, "finish_reason": recv_obj.finished_reasons[i], @@ -191,23 +212,7 @@ def postprocess_response( } ) - if isinstance(recv_obj, BatchStrOut): - return { - "text": recv_obj.output_strs[i], - "meta_info": meta_info, - } - elif isinstance(recv_obj, BatchTokenIDOut): - return { - "token_ids": recv_obj.output_ids[i], - "meta_info": meta_info, - } - elif isinstance(recv_obj, BatchEmbeddingOut): - return { - "embedding": recv_obj.embeddings[i], - "meta_info": meta_info, - } - else: - raise NotImplementedError + return meta_info class _MetricManager: From 02c451cc0a2fdab9d3ecad1b8f93d17b2d485961 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sat, 25 Jan 2025 09:31:39 +0800 Subject: [PATCH 013/136] mv convert_logprob_style etc --- .../sglang/srt/managers/generation_manager.py | 67 +++++++++++++++++++ .../sglang/srt/managers/tokenizer_manager.py | 67 ------------------- 2 files changed, 67 insertions(+), 67 deletions(-) diff --git a/python/sglang/srt/managers/generation_manager.py b/python/sglang/srt/managers/generation_manager.py index 1dd6629b776..fffb005326b 100644 --- a/python/sglang/srt/managers/generation_manager.py +++ b/python/sglang/srt/managers/generation_manager.py @@ -214,6 +214,73 @@ def _compute_meta_info(self, index, recv_obj, req_obj): return meta_info + def convert_logprob_style( + self, + meta_info: dict, + top_logprobs_num: int, + return_text_in_logprobs: bool, + recv_obj: BatchStrOut, + recv_obj_index: int, + ): + meta_info["input_token_logprobs"] = 
self.detokenize_logprob_tokens( + recv_obj.input_token_logprobs_val[recv_obj_index], + recv_obj.input_token_logprobs_idx[recv_obj_index], + return_text_in_logprobs, + ) + meta_info["output_token_logprobs"] = self.detokenize_logprob_tokens( + recv_obj.output_token_logprobs_val[recv_obj_index], + recv_obj.output_token_logprobs_idx[recv_obj_index], + return_text_in_logprobs, + ) + + if top_logprobs_num > 0: + meta_info["input_top_logprobs"] = self.detokenize_top_logprobs_tokens( + recv_obj.input_top_logprobs_val[recv_obj_index], + recv_obj.input_top_logprobs_idx[recv_obj_index], + return_text_in_logprobs, + ) + meta_info["output_top_logprobs"] = self.detokenize_top_logprobs_tokens( + recv_obj.output_top_logprobs_val[recv_obj_index], + recv_obj.output_top_logprobs_idx[recv_obj_index], + return_text_in_logprobs, + ) + + def detokenize_logprob_tokens( + self, + token_logprobs_val: List[float], + token_logprobs_idx: List[int], + decode_to_text: bool, + ): + if not decode_to_text: + return [ + (logprob, token_id, None) + for logprob, token_id in zip(token_logprobs_val, token_logprobs_idx) + ] + else: + assert self.tokenizer is not None + token_texts = self.tokenizer.batch_decode(token_logprobs_idx) + return list(zip(token_logprobs_val, token_logprobs_idx, token_texts)) + + def detokenize_top_logprobs_tokens( + self, + token_logprobs_val: List[float], + token_logprobs_idx: List[int], + decode_to_text: bool, + ): + # TODO: The current implementation only batches the detokenization for top-k tokens per single position. + # We should batch all top-k tokens in all positions. + ret = [] + for i in range(len(token_logprobs_val)): + if token_logprobs_val[i]: + ret.append( + self.detokenize_logprob_tokens( + token_logprobs_val[i], token_logprobs_idx[i], decode_to_text + ) + ) + else: + ret.append(None) + return ret + class _MetricManager: def __init__(self, server_args: ServerArgs): diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index ff495ce42c3..91da78092ad 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -634,73 +634,6 @@ def _handle_batch_output( if self.dump_requests_folder and state.finished and state.obj.log_metrics: self.dump_requests(state, out_dict) - def convert_logprob_style( - self, - meta_info: dict, - top_logprobs_num: int, - return_text_in_logprobs: bool, - recv_obj: BatchStrOut, - recv_obj_index: int, - ): - meta_info["input_token_logprobs"] = self.detokenize_logprob_tokens( - recv_obj.input_token_logprobs_val[recv_obj_index], - recv_obj.input_token_logprobs_idx[recv_obj_index], - return_text_in_logprobs, - ) - meta_info["output_token_logprobs"] = self.detokenize_logprob_tokens( - recv_obj.output_token_logprobs_val[recv_obj_index], - recv_obj.output_token_logprobs_idx[recv_obj_index], - return_text_in_logprobs, - ) - - if top_logprobs_num > 0: - meta_info["input_top_logprobs"] = self.detokenize_top_logprobs_tokens( - recv_obj.input_top_logprobs_val[recv_obj_index], - recv_obj.input_top_logprobs_idx[recv_obj_index], - return_text_in_logprobs, - ) - meta_info["output_top_logprobs"] = self.detokenize_top_logprobs_tokens( - recv_obj.output_top_logprobs_val[recv_obj_index], - recv_obj.output_top_logprobs_idx[recv_obj_index], - return_text_in_logprobs, - ) - - def detokenize_logprob_tokens( - self, - token_logprobs_val: List[float], - token_logprobs_idx: List[int], - decode_to_text: bool, - ): - if not decode_to_text: - return [ - (logprob, token_id, None) - for 
logprob, token_id in zip(token_logprobs_val, token_logprobs_idx) - ] - else: - assert self.tokenizer is not None - token_texts = self.tokenizer.batch_decode(token_logprobs_idx) - return list(zip(token_logprobs_val, token_logprobs_idx, token_texts)) - - def detokenize_top_logprobs_tokens( - self, - token_logprobs_val: List[float], - token_logprobs_idx: List[int], - decode_to_text: bool, - ): - # TODO: The current implementation only batches the detokenization for top-k tokens per single position. - # We should batch all top-k tokens in all positions. - ret = [] - for i in range(len(token_logprobs_val)): - if token_logprobs_val[i]: - ret.append( - self.detokenize_logprob_tokens( - token_logprobs_val[i], token_logprobs_idx[i], decode_to_text - ) - ) - else: - ret.append(None) - return ret - def collect_metrics(self, state: ReqState, recv_obj: BatchStrOut, i: int): TODO_moved From ccd5e8a8aeed81e94f2d898c7beabaa2ba177188 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sat, 25 Jan 2025 09:31:53 +0800 Subject: [PATCH 014/136] make private --- .../sglang/srt/managers/generation_manager.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/python/sglang/srt/managers/generation_manager.py b/python/sglang/srt/managers/generation_manager.py index fffb005326b..755df688bfa 100644 --- a/python/sglang/srt/managers/generation_manager.py +++ b/python/sglang/srt/managers/generation_manager.py @@ -196,7 +196,7 @@ def _compute_meta_info(self, index, recv_obj, req_obj): } if getattr(state.obj, "return_logprob", False): - self.convert_logprob_style( + self._convert_logprob_style( meta_info, state.obj.top_logprobs_num, state.obj.return_text_in_logprobs, @@ -214,7 +214,7 @@ def _compute_meta_info(self, index, recv_obj, req_obj): return meta_info - def convert_logprob_style( + def _convert_logprob_style( self, meta_info: dict, top_logprobs_num: int, @@ -222,30 +222,30 @@ def convert_logprob_style( recv_obj: BatchStrOut, recv_obj_index: int, ): - meta_info["input_token_logprobs"] = self.detokenize_logprob_tokens( + meta_info["input_token_logprobs"] = self._detokenize_logprob_tokens( recv_obj.input_token_logprobs_val[recv_obj_index], recv_obj.input_token_logprobs_idx[recv_obj_index], return_text_in_logprobs, ) - meta_info["output_token_logprobs"] = self.detokenize_logprob_tokens( + meta_info["output_token_logprobs"] = self._detokenize_logprob_tokens( recv_obj.output_token_logprobs_val[recv_obj_index], recv_obj.output_token_logprobs_idx[recv_obj_index], return_text_in_logprobs, ) if top_logprobs_num > 0: - meta_info["input_top_logprobs"] = self.detokenize_top_logprobs_tokens( + meta_info["input_top_logprobs"] = self._detokenize_top_logprobs_tokens( recv_obj.input_top_logprobs_val[recv_obj_index], recv_obj.input_top_logprobs_idx[recv_obj_index], return_text_in_logprobs, ) - meta_info["output_top_logprobs"] = self.detokenize_top_logprobs_tokens( + meta_info["output_top_logprobs"] = self._detokenize_top_logprobs_tokens( recv_obj.output_top_logprobs_val[recv_obj_index], recv_obj.output_top_logprobs_idx[recv_obj_index], return_text_in_logprobs, ) - def detokenize_logprob_tokens( + def _detokenize_logprob_tokens( self, token_logprobs_val: List[float], token_logprobs_idx: List[int], @@ -261,7 +261,7 @@ def detokenize_logprob_tokens( token_texts = self.tokenizer.batch_decode(token_logprobs_idx) return list(zip(token_logprobs_val, token_logprobs_idx, token_texts)) - def detokenize_top_logprobs_tokens( + def _detokenize_top_logprobs_tokens( self, token_logprobs_val: List[float], 
token_logprobs_idx: List[int], @@ -273,7 +273,7 @@ def detokenize_top_logprobs_tokens( for i in range(len(token_logprobs_val)): if token_logprobs_val[i]: ret.append( - self.detokenize_logprob_tokens( + self._detokenize_logprob_tokens( token_logprobs_val[i], token_logprobs_idx[i], decode_to_text ) ) From ecf5e2119abbb055753cb89301524508b702b418 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sat, 25 Jan 2025 09:33:05 +0800 Subject: [PATCH 015/136] mv GenerationManager.init --- .../sglang/srt/managers/generation_manager.py | 23 +++++++++++++++++-- .../sglang/srt/managers/tokenizer_manager.py | 5 ---- 2 files changed, 21 insertions(+), 7 deletions(-) diff --git a/python/sglang/srt/managers/generation_manager.py b/python/sglang/srt/managers/generation_manager.py index 755df688bfa..90bd2be017f 100644 --- a/python/sglang/srt/managers/generation_manager.py +++ b/python/sglang/srt/managers/generation_manager.py @@ -2,7 +2,7 @@ import dataclasses import os import time -from typing import Optional, List, Any, Union, Dict +from typing import Optional, List, Any, Union, Dict, Callable from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer from sglang.srt.managers.image_processor import get_dummy_image_processor, get_image_processor @@ -14,7 +14,26 @@ class GenerationManager: - pass + def __init__( + self, + server_args: ServerArgs, + on_request: Callable, + ): + self.server_args = server_args + self.on_request = on_request + + self.model_config = _compute_model_config(server_args) + self._generation_converter = GenerationConverter(server_args=server_args) + + self.rid_to_state: Dict[str, _ReqState] = {} + + # Metrics + if server_args.enable_metrics: + self._metric_manager = _MetricManager( + server_args=server_args, + ) + else: + self._metric_manager = None class GenerationConverter: diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 91da78092ad..631d38d6012 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -131,7 +131,6 @@ def __init__( # Store states self.no_create_loop = False - self.rid_to_state: Dict[str, ReqState] = {} self.dump_requests_folder = "" # By default do not dump self.dump_requests_threshold = 1000 self.dump_request_list: List[Tuple] = [] @@ -169,10 +168,6 @@ def __init__( # Set after scheduler is initialized self.max_req_input_len = None - # Metrics - if self.enable_metrics: - TODO_moved - self._result_dispatcher = TypeBasedDispatcher( [ ( From 818f8cd79c19767fafe38602983b52e45a91c873 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sat, 25 Jan 2025 09:34:52 +0800 Subject: [PATCH 016/136] mv GenerationManager body --- .../sglang/srt/managers/generation_manager.py | 203 ++++++++++++++++++ .../sglang/srt/managers/tokenizer_manager.py | 193 +---------------- 2 files changed, 205 insertions(+), 191 deletions(-) diff --git a/python/sglang/srt/managers/generation_manager.py b/python/sglang/srt/managers/generation_manager.py index 90bd2be017f..f559b815295 100644 --- a/python/sglang/srt/managers/generation_manager.py +++ b/python/sglang/srt/managers/generation_manager.py @@ -35,6 +35,209 @@ def __init__( else: self._metric_manager = None + async def generate( + self, + obj: Union[GenerateReqInput, EmbeddingReqInput], + request: Optional[fastapi.Request] = None, + ): + created_time = time.time() + + self.auto_create_handle_loop() + + if isinstance(obj, EmbeddingReqInput) and self.is_generation: + raise ValueError( + "This model does not appear to be 
an embedding model by default. " + "Please add `--is-embedding` when launching the server or try another model." + ) + + obj.normalize_batch_and_arguments() + + if self.log_requests: + max_length = 2048 if self.log_requests_level == 0 else 1 << 30 + logger.info( + f"Receive: obj={dataclass_to_string_truncated(obj, max_length)}" + ) + + async with self.model_update_lock.reader_lock: + is_single = obj.is_single + if is_single: + tokenized_obj = await self._tokenize_one_request(obj) + self._send_one_request(obj, tokenized_obj, created_time) + async for response in self._wait_one_response(obj, request): + yield response + else: + async for response in self._handle_batch_request( + obj, request, created_time + ): + yield response + + def _send_one_request( + self, + obj: Union[GenerateReqInput, EmbeddingReqInput], + tokenized_obj: Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput], + created_time: Optional[float] = None, + ): + event = asyncio.Event() + state = ReqState([], False, event, obj, created_time=created_time) + self.rid_to_state[obj.rid] = state + self.send_to_scheduler.send_pyobj(tokenized_obj) + + async def _wait_one_response( + self, + obj: Union[GenerateReqInput, EmbeddingReqInput], + request: Optional[fastapi.Request] = None, + ): + """Wait for the response of one request.""" + state = self.rid_to_state[obj.rid] + + while True: + try: + await asyncio.wait_for(state.event.wait(), timeout=4) + except asyncio.TimeoutError: + if request is not None and await request.is_disconnected(): + self.abort_request(obj.rid) + raise ValueError(f"Abort request {obj.rid}") + continue + + out = state.out_list[-1] + + state.out_list = [] + if state.finished: + if self.log_requests: + max_length = 2048 if self.log_requests_level == 0 else 1 << 30 + msg = f"Finish: obj={dataclass_to_string_truncated(obj, max_length)}, out={dataclass_to_string_truncated(out, max_length)}" + logger.info(msg) + del self.rid_to_state[obj.rid] + + # Check if this was an abort/error created by scheduler + if isinstance(out["meta_info"].get("finish_reason"), dict): + finish_reason = out["meta_info"]["finish_reason"] + if ( + finish_reason.get("type") == "abort" + and finish_reason.get("status_code") == HTTPStatus.BAD_REQUEST + ): + raise ValueError(finish_reason["message"]) + + yield out + break + + state.event.clear() + + if obj.stream: + yield out + else: + if request is not None and await request.is_disconnected(): + self.abort_request(obj.rid) + raise ValueError(f"Abort request {obj.rid}") + + async def _handle_batch_request( + self, + obj: Union[GenerateReqInput, EmbeddingReqInput], + request: Optional[fastapi.Request] = None, + created_time: Optional[float] = None, + ): + batch_size = obj.batch_size + + generators = [] + rids = [] + if getattr(obj, "parallel_sample_num", 1) == 1: + # Send all requests + for i in range(batch_size): + tmp_obj = obj[i] + tokenized_obj = await self._tokenize_one_request(tmp_obj) + self._send_one_request(tmp_obj, tokenized_obj, created_time) + generators.append(self._wait_one_response(tmp_obj, request)) + rids.append(tmp_obj.rid) + else: + # FIXME: When using batch and parallel_sample_num together, the perf is not optimal. + if batch_size > 128: + logger.warning( + "Sending a single large batch with parallel sampling (n > 1) has not been well optimized. " + "The performance might be better if you just duplicate the requests n times or use " + "many threads to send them one by one with parallel sampling (n > 1)." 
+ ) + + # Tokenize all requests + objs = [obj[i] for i in range(batch_size)] + tokenized_objs = await asyncio.gather( + *(self._tokenize_one_request(obj) for obj in objs) + ) + + # Cache the common prefix for parallel sampling + for i in range(batch_size): + tmp_obj = copy.copy(objs[i]) + tokenized_obj = copy.copy(tokenized_objs[i]) + tokenized_obj.rid = tmp_obj.regenerate_rid() + tokenized_obj.sampling_params = copy.copy(tokenized_obj.sampling_params) + tokenized_obj.sampling_params.max_new_tokens = 0 + tokenized_obj.stream = False + self._send_one_request(tmp_obj, tokenized_obj, created_time) + await self._wait_one_response(tmp_obj, request).__anext__() + + # Expand requests, assign new rids for them, and send them + for i in range(batch_size): + for _ in range(obj.parallel_sample_num): + tmp_obj = copy.copy(objs[i]) + tokenized_obj = copy.copy(tokenized_objs[i]) + tokenized_obj.rid = tmp_obj.regenerate_rid() + self._send_one_request(tmp_obj, tokenized_obj, created_time) + generators.append(self._wait_one_response(tmp_obj, request)) + rids.append(tmp_obj.rid) + + # Wait for all requests + is_stream = hasattr(obj, "stream") and obj.stream + if not is_stream: + outputs = await asyncio.gather(*(gen.__anext__() for gen in generators)) + yield outputs + else: + rid_to_index = {rid: i for i, rid in enumerate(rids)} + task_map = {asyncio.create_task(gen.__anext__()): gen for gen in generators} + while task_map: + done, _ = await asyncio.wait( + task_map.keys(), return_when=asyncio.FIRST_COMPLETED + ) + + for task in done: + gen = task_map.pop(task) + try: + result = task.result() + result["index"] = rid_to_index[result["meta_info"]["id"]] + yield result + new_task = asyncio.create_task(gen.__anext__()) + task_map[new_task] = gen + except StopAsyncIteration: + pass + + def _handle_batch_output( + self, recv_obj: Union[BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut] + ): + for i, rid in enumerate(recv_obj.rids): + state = self.rid_to_state.get(rid, None) + if state is None: + continue + + TODO_moved_postprocesse_response + + state.out_list.append(out_dict) + state.finished = recv_obj.finished_reasons[i] is not None + state.event.set() + + if self.enable_metrics and state.obj.log_metrics: + self.collect_metrics(state, recv_obj, i) + if self.dump_requests_folder and state.finished and state.obj.log_metrics: + self.dump_requests(state, out_dict) + + def abort_request(self, rid: str): + if rid not in self.rid_to_state: + return + del self.rid_to_state[rid] + req = AbortReq(rid) + self.send_to_scheduler.send_pyobj(req) + + @property + def tokenizer(self): + return self._generation_converter.tokenizer + class GenerationConverter: """Preprocessors and postprocessors for generation""" diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 631d38d6012..1517254af37 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -211,184 +211,14 @@ async def generate_request( obj: Union[GenerateReqInput, EmbeddingReqInput], request: Optional[fastapi.Request] = None, ): - created_time = time.time() - - self.auto_create_handle_loop() - - if isinstance(obj, EmbeddingReqInput) and self.is_generation: - raise ValueError( - "This model does not appear to be an embedding model by default. " - "Please add `--is-embedding` when launching the server or try another model." 
- ) - - obj.normalize_batch_and_arguments() - - if self.log_requests: - max_length = 2048 if self.log_requests_level == 0 else 1 << 30 - logger.info( - f"Receive: obj={dataclass_to_string_truncated(obj, max_length)}" - ) - - async with self.model_update_lock.reader_lock: - is_single = obj.is_single - if is_single: - tokenized_obj = await self._tokenize_one_request(obj) - self._send_one_request(obj, tokenized_obj, created_time) - async for response in self._wait_one_response(obj, request): - yield response - else: - async for response in self._handle_batch_request( - obj, request, created_time - ): - yield response - - def _send_one_request( - self, - obj: Union[GenerateReqInput, EmbeddingReqInput], - tokenized_obj: Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput], - created_time: Optional[float] = None, - ): - event = asyncio.Event() - state = ReqState([], False, event, obj, created_time=created_time) - self.rid_to_state[obj.rid] = state - self.send_to_scheduler.send_pyobj(tokenized_obj) - - async def _wait_one_response( - self, - obj: Union[GenerateReqInput, EmbeddingReqInput], - request: Optional[fastapi.Request] = None, - ): - """Wait for the response of one request.""" - state = self.rid_to_state[obj.rid] - - while True: - try: - await asyncio.wait_for(state.event.wait(), timeout=4) - except asyncio.TimeoutError: - if request is not None and await request.is_disconnected(): - self.abort_request(obj.rid) - raise ValueError(f"Abort request {obj.rid}") - continue - - out = state.out_list[-1] - - state.out_list = [] - if state.finished: - if self.log_requests: - max_length = 2048 if self.log_requests_level == 0 else 1 << 30 - msg = f"Finish: obj={dataclass_to_string_truncated(obj, max_length)}, out={dataclass_to_string_truncated(out, max_length)}" - logger.info(msg) - del self.rid_to_state[obj.rid] - - # Check if this was an abort/error created by scheduler - if isinstance(out["meta_info"].get("finish_reason"), dict): - finish_reason = out["meta_info"]["finish_reason"] - if ( - finish_reason.get("type") == "abort" - and finish_reason.get("status_code") == HTTPStatus.BAD_REQUEST - ): - raise ValueError(finish_reason["message"]) - - yield out - break - - state.event.clear() - - if obj.stream: - yield out - else: - if request is not None and await request.is_disconnected(): - self.abort_request(obj.rid) - raise ValueError(f"Abort request {obj.rid}") - - async def _handle_batch_request( - self, - obj: Union[GenerateReqInput, EmbeddingReqInput], - request: Optional[fastapi.Request] = None, - created_time: Optional[float] = None, - ): - batch_size = obj.batch_size - - generators = [] - rids = [] - if getattr(obj, "parallel_sample_num", 1) == 1: - # Send all requests - for i in range(batch_size): - tmp_obj = obj[i] - tokenized_obj = await self._tokenize_one_request(tmp_obj) - self._send_one_request(tmp_obj, tokenized_obj, created_time) - generators.append(self._wait_one_response(tmp_obj, request)) - rids.append(tmp_obj.rid) - else: - # FIXME: When using batch and parallel_sample_num together, the perf is not optimal. - if batch_size > 128: - logger.warning( - "Sending a single large batch with parallel sampling (n > 1) has not been well optimized. " - "The performance might be better if you just duplicate the requests n times or use " - "many threads to send them one by one with parallel sampling (n > 1)." 
- ) - - # Tokenize all requests - objs = [obj[i] for i in range(batch_size)] - tokenized_objs = await asyncio.gather( - *(self._tokenize_one_request(obj) for obj in objs) - ) - - # Cache the common prefix for parallel sampling - for i in range(batch_size): - tmp_obj = copy.copy(objs[i]) - tokenized_obj = copy.copy(tokenized_objs[i]) - tokenized_obj.rid = tmp_obj.regenerate_rid() - tokenized_obj.sampling_params = copy.copy(tokenized_obj.sampling_params) - tokenized_obj.sampling_params.max_new_tokens = 0 - tokenized_obj.stream = False - self._send_one_request(tmp_obj, tokenized_obj, created_time) - await self._wait_one_response(tmp_obj, request).__anext__() - - # Expand requests, assign new rids for them, and send them - for i in range(batch_size): - for _ in range(obj.parallel_sample_num): - tmp_obj = copy.copy(objs[i]) - tokenized_obj = copy.copy(tokenized_objs[i]) - tokenized_obj.rid = tmp_obj.regenerate_rid() - self._send_one_request(tmp_obj, tokenized_obj, created_time) - generators.append(self._wait_one_response(tmp_obj, request)) - rids.append(tmp_obj.rid) - - # Wait for all requests - is_stream = hasattr(obj, "stream") and obj.stream - if not is_stream: - outputs = await asyncio.gather(*(gen.__anext__() for gen in generators)) - yield outputs - else: - rid_to_index = {rid: i for i, rid in enumerate(rids)} - task_map = {asyncio.create_task(gen.__anext__()): gen for gen in generators} - while task_map: - done, _ = await asyncio.wait( - task_map.keys(), return_when=asyncio.FIRST_COMPLETED - ) - - for task in done: - gen = task_map.pop(task) - try: - result = task.result() - result["index"] = rid_to_index[result["meta_info"]["id"]] - yield result - new_task = asyncio.create_task(gen.__anext__()) - task_map[new_task] = gen - except StopAsyncIteration: - pass + TODO_moved_to_generate def flush_cache(self): req = FlushCacheReq() self.send_to_scheduler.send_pyobj(req) def abort_request(self, rid: str): - if rid not in self.rid_to_state: - return - del self.rid_to_state[rid] - req = AbortReq(rid) - self.send_to_scheduler.send_pyobj(req) + TODO_moved def start_profile(self): req = ProfileReq.START_PROFILE @@ -610,25 +440,6 @@ async def handle_loop(self): recv_obj = await self.recv_from_detokenizer.recv_pyobj() self._result_dispatcher(recv_obj) - def _handle_batch_output( - self, recv_obj: Union[BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut] - ): - for i, rid in enumerate(recv_obj.rids): - state = self.rid_to_state.get(rid, None) - if state is None: - continue - - TODO_moved_postprocesse_response - - state.out_list.append(out_dict) - state.finished = recv_obj.finished_reasons[i] is not None - state.event.set() - - if self.enable_metrics and state.obj.log_metrics: - self.collect_metrics(state, recv_obj, i) - if self.dump_requests_folder and state.finished and state.obj.log_metrics: - self.dump_requests(state, out_dict) - def collect_metrics(self, state: ReqState, recv_obj: BatchStrOut, i: int): TODO_moved From 022eb4fb9febb3599cee05d6dcb0717922c67150 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sat, 25 Jan 2025 09:35:31 +0800 Subject: [PATCH 017/136] fix import --- python/sglang/srt/managers/generation_manager.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/managers/generation_manager.py b/python/sglang/srt/managers/generation_manager.py index f559b815295..b37326ee279 100644 --- a/python/sglang/srt/managers/generation_manager.py +++ b/python/sglang/srt/managers/generation_manager.py @@ -1,16 +1,23 @@ import asyncio +import copy import 
dataclasses +import logging import os import time +from http import HTTPStatus from typing import Optional, List, Any, Union, Dict, Callable +import fastapi from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer from sglang.srt.managers.image_processor import get_dummy_image_processor, get_image_processor from sglang.srt.managers.io_struct import GenerateReqInput, EmbeddingReqInput, SessionParams, TokenizedGenerateReqInput, \ - TokenizedEmbeddingReqInput, BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut + TokenizedEmbeddingReqInput, BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut, AbortReq from sglang.srt.metrics.collector import TokenizerMetricsCollector from sglang.srt.sampling.sampling_params import SamplingParams from sglang.srt.server_args import ServerArgs +from sglang.srt.utils import dataclass_to_string_truncated + +logger = logging.getLogger(__name__) class GenerationManager: @@ -78,7 +85,7 @@ def _send_one_request( created_time: Optional[float] = None, ): event = asyncio.Event() - state = ReqState([], False, event, obj, created_time=created_time) + state = _ReqState([], False, event, obj, created_time=created_time) self.rid_to_state[obj.rid] = state self.send_to_scheduler.send_pyobj(tokenized_obj) From dc53f8fb9ce559c57acf42375602da7ea1781d84 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sat, 25 Jan 2025 09:36:05 +0800 Subject: [PATCH 018/136] mv modelconfig --- python/sglang/srt/managers/generation_manager.py | 14 ++++++++++++++ python/sglang/srt/managers/tokenizer_manager.py | 10 ---------- 2 files changed, 14 insertions(+), 10 deletions(-) diff --git a/python/sglang/srt/managers/generation_manager.py b/python/sglang/srt/managers/generation_manager.py index b37326ee279..d190e25b26e 100644 --- a/python/sglang/srt/managers/generation_manager.py +++ b/python/sglang/srt/managers/generation_manager.py @@ -8,6 +8,7 @@ from typing import Optional, List, Any, Union, Dict, Callable import fastapi +from sglang.srt.configs.model_config import ModelConfig from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer from sglang.srt.managers.image_processor import get_dummy_image_processor, get_image_processor from sglang.srt.managers.io_struct import GenerateReqInput, EmbeddingReqInput, SessionParams, TokenizedGenerateReqInput, \ @@ -511,6 +512,19 @@ def _detokenize_top_logprobs_tokens( return ret +def _compute_model_config(server_args: ServerArgs): + return ModelConfig( + server_args.model_path, + trust_remote_code=server_args.trust_remote_code, + revision=server_args.revision, + context_length=server_args.context_length, + model_override_args=server_args.json_model_override_args, + is_embedding=server_args.is_embedding, + dtype=server_args.dtype, + quantization=server_args.quantization, + ) + + class _MetricManager: def __init__(self, server_args: ServerArgs): self.metrics_collector = TokenizerMetricsCollector( diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 1517254af37..98b8ec75e03 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -114,16 +114,6 @@ def __init__( # Read model args self.model_path = server_args.model_path self.served_model_name = server_args.served_model_name - self.model_config = ModelConfig( - server_args.model_path, - trust_remote_code=server_args.trust_remote_code, - revision=server_args.revision, - context_length=server_args.context_length, - model_override_args=server_args.json_model_override_args, - 
is_embedding=server_args.is_embedding, - dtype=server_args.dtype, - quantization=server_args.quantization, - ) self.is_generation = self.model_config.is_generation self.context_len = self.model_config.context_len From c4f166809e50408f797a207c922ec4cd5babcfbf Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sat, 25 Jan 2025 09:36:30 +0800 Subject: [PATCH 019/136] call generation_converter --- python/sglang/srt/managers/generation_manager.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/managers/generation_manager.py b/python/sglang/srt/managers/generation_manager.py index d190e25b26e..e4aa4f0dfe3 100644 --- a/python/sglang/srt/managers/generation_manager.py +++ b/python/sglang/srt/managers/generation_manager.py @@ -224,7 +224,9 @@ def _handle_batch_output( if state is None: continue - TODO_moved_postprocesse_response + out_dict = self._generation_converter.postprocess_response( + recv_obj, index, state.obj + ) state.out_list.append(out_dict) state.finished = recv_obj.finished_reasons[i] is not None From 1670ce16babc2c7851c3f362cc4ec1e5e3bca9f2 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sat, 25 Jan 2025 09:37:18 +0800 Subject: [PATCH 020/136] fix metrics --- .../sglang/srt/managers/generation_manager.py | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/python/sglang/srt/managers/generation_manager.py b/python/sglang/srt/managers/generation_manager.py index e4aa4f0dfe3..a585ec6ce69 100644 --- a/python/sglang/srt/managers/generation_manager.py +++ b/python/sglang/srt/managers/generation_manager.py @@ -86,7 +86,7 @@ def _send_one_request( created_time: Optional[float] = None, ): event = asyncio.Event() - state = _ReqState([], False, event, obj, created_time=created_time) + state = _ReqState([], False, event, obj, metric=_MetricReqState(created_time=created_time)) self.rid_to_state[obj.rid] = state self.send_to_scheduler.send_pyobj(tokenized_obj) @@ -219,7 +219,7 @@ async def _handle_batch_request( def _handle_batch_output( self, recv_obj: Union[BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut] ): - for i, rid in enumerate(recv_obj.rids): + for index, rid in enumerate(recv_obj.rids): state = self.rid_to_state.get(rid, None) if state is None: continue @@ -229,11 +229,18 @@ def _handle_batch_output( ) state.out_list.append(out_dict) - state.finished = recv_obj.finished_reasons[i] is not None + state.finished = recv_obj.finished_reasons[index] is not None state.event.set() - if self.enable_metrics and state.obj.log_metrics: - self.collect_metrics(state, recv_obj, i) + if self._metric_manager: + self._metric_manager.handle_batch_output_metrics( + recv_obj, + index, + state.metric, + finished=state.finished, + stream=state.obj.stream if hasattr(state.obj, "stream") else None, + ) + if self.dump_requests_folder and state.finished and state.obj.log_metrics: self.dump_requests(state, out_dict) @@ -585,6 +592,8 @@ class _ReqState: event: asyncio.Event obj: Any + metric: "_MetricReqState" + # For streaming output last_output_offset: int = 0 From 905d24716e01324450ce899753bcaab75532342a Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sat, 25 Jan 2025 09:38:01 +0800 Subject: [PATCH 021/136] fix err --- .../sglang/srt/managers/generation_manager.py | 24 +++++++++---------- 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/python/sglang/srt/managers/generation_manager.py b/python/sglang/srt/managers/generation_manager.py index a585ec6ce69..fba14673aae 100644 --- a/python/sglang/srt/managers/generation_manager.py +++ 
b/python/sglang/srt/managers/generation_manager.py @@ -411,17 +411,17 @@ def postprocess_response( if isinstance(recv_obj, BatchStrOut): return { - "text": recv_obj.output_strs[i], + "text": recv_obj.output_strs[index], "meta_info": meta_info, } elif isinstance(recv_obj, BatchTokenIDOut): return { - "token_ids": recv_obj.output_ids[i], + "token_ids": recv_obj.output_ids[index], "meta_info": meta_info, } elif isinstance(recv_obj, BatchEmbeddingOut): return { - "embedding": recv_obj.embeddings[i], + "embedding": recv_obj.embeddings[index], "meta_info": meta_info, } else: @@ -429,25 +429,25 @@ def postprocess_response( def _compute_meta_info(self, index, recv_obj, req_obj): meta_info = { - "id": rid, - "finish_reason": recv_obj.finished_reasons[i], - "prompt_tokens": recv_obj.prompt_tokens[i], + "id": recv_obj.rids[index], + "finish_reason": recv_obj.finished_reasons[index], + "prompt_tokens": recv_obj.prompt_tokens[index], } - if getattr(state.obj, "return_logprob", False): + if getattr(req_obj, "return_logprob", False): self._convert_logprob_style( meta_info, - state.obj.top_logprobs_num, - state.obj.return_text_in_logprobs, + req_obj.top_logprobs_num, + req_obj.return_text_in_logprobs, recv_obj, - i, + index, ) if not isinstance(recv_obj, BatchEmbeddingOut): meta_info.update( { - "completion_tokens": recv_obj.completion_tokens[i], - "cached_tokens": recv_obj.cached_tokens[i], + "completion_tokens": recv_obj.completion_tokens[index], + "cached_tokens": recv_obj.cached_tokens[index], } ) From 41bee7dd59779a48cc92715c381df2e016cab7de Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sat, 25 Jan 2025 09:39:26 +0800 Subject: [PATCH 022/136] handle tokenizer_manager.generate_request --- .../sglang/srt/managers/generation_manager.py | 25 ++++++++----------- .../sglang/srt/managers/tokenizer_manager.py | 5 +++- 2 files changed, 15 insertions(+), 15 deletions(-) diff --git a/python/sglang/srt/managers/generation_manager.py b/python/sglang/srt/managers/generation_manager.py index fba14673aae..f0186816c41 100644 --- a/python/sglang/srt/managers/generation_manager.py +++ b/python/sglang/srt/managers/generation_manager.py @@ -50,8 +50,6 @@ async def generate( ): created_time = time.time() - self.auto_create_handle_loop() - if isinstance(obj, EmbeddingReqInput) and self.is_generation: raise ValueError( "This model does not appear to be an embedding model by default. 
" @@ -66,18 +64,17 @@ async def generate( f"Receive: obj={dataclass_to_string_truncated(obj, max_length)}" ) - async with self.model_update_lock.reader_lock: - is_single = obj.is_single - if is_single: - tokenized_obj = await self._tokenize_one_request(obj) - self._send_one_request(obj, tokenized_obj, created_time) - async for response in self._wait_one_response(obj, request): - yield response - else: - async for response in self._handle_batch_request( - obj, request, created_time - ): - yield response + is_single = obj.is_single + if is_single: + tokenized_obj = await self._tokenize_one_request(obj) + self._send_one_request(obj, tokenized_obj, created_time) + async for response in self._wait_one_response(obj, request): + yield response + else: + async for response in self._handle_batch_request( + obj, request, created_time + ): + yield response def _send_one_request( self, diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 98b8ec75e03..70567581bbc 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -201,7 +201,10 @@ async def generate_request( obj: Union[GenerateReqInput, EmbeddingReqInput], request: Optional[fastapi.Request] = None, ): - TODO_moved_to_generate + self.auto_create_handle_loop() + async with self.model_update_lock.reader_lock: + async for value in self._generation_manager.generate(obj, request): + yield value def flush_cache(self): req = FlushCacheReq() From 2b3ca9686c31d6ebf8f30e8cb32a3ea74d000c8d Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sat, 25 Jan 2025 09:39:42 +0800 Subject: [PATCH 023/136] handle abort_request --- python/sglang/srt/managers/tokenizer_manager.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 70567581bbc..fea1cb0743a 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -211,7 +211,7 @@ def flush_cache(self): self.send_to_scheduler.send_pyobj(req) def abort_request(self, rid: str): - TODO_moved + self._generation_manager.abort_request(rid) def start_profile(self): req = ProfileReq.START_PROFILE From e293f1fd6382e75372856747591a518f6a94a491 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sat, 25 Jan 2025 09:40:06 +0800 Subject: [PATCH 024/136] add field --- python/sglang/srt/managers/tokenizer_manager.py | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index fea1cb0743a..e17873148e8 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -35,12 +35,7 @@ from fastapi import BackgroundTasks from sglang.srt.aio_rwlock import RWLock -from sglang.srt.configs.model_config import ModelConfig -from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer -from sglang.srt.managers.image_processor import ( - get_dummy_image_processor, - get_image_processor, -) +from sglang.srt.managers.generation_manager import GenerationManager from sglang.srt.managers.io_struct import ( AbortReq, BatchEmbeddingOut, @@ -62,9 +57,7 @@ ReleaseMemoryOccupationReqOutput, ResumeMemoryOccupationReqInput, ResumeMemoryOccupationReqOutput, - SessionParams, TokenizedEmbeddingReqInput, - TokenizedGenerateReqInput, UpdateWeightFromDiskReqInput, UpdateWeightFromDiskReqOutput, 
UpdateWeightsFromDistributedReqInput, @@ -73,7 +66,6 @@ UpdateWeightsFromTensorReqOutput, ) from sglang.srt.metrics.collector import TokenizerMetricsCollector -from sglang.srt.sampling.sampling_params import SamplingParams from sglang.srt.server_args import PortArgs, ServerArgs from sglang.srt.utils import ( dataclass_to_string_truncated, @@ -135,6 +127,11 @@ def __init__( # For session info self.session_futures = {} # session_id -> asyncio event + self._generation_manager = GenerationManager( + server_args=server_args, + on_request=self.send_to_scheduler.send_pyobj, + ) + # Others self.gracefully_exit = False self.init_weights_update_group_communicator = _Communicator( From 2424cf2254ae5010f00020f23b9f4db5245dd72b Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sat, 25 Jan 2025 09:40:15 +0800 Subject: [PATCH 025/136] rm empty func --- python/sglang/srt/managers/tokenizer_manager.py | 13 +------------ 1 file changed, 1 insertion(+), 12 deletions(-) diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index e17873148e8..44ba4b114d8 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -14,8 +14,6 @@ """TokenizerManager is a process that tokenizes the text.""" import asyncio -import copy -import dataclasses import logging import os import pickle @@ -25,19 +23,16 @@ import time import uuid from datetime import datetime -from http import HTTPStatus -from typing import Any, Awaitable, Dict, Generic, List, Optional, Tuple, TypeVar, Union +from typing import Awaitable, Generic, List, Optional, Tuple, TypeVar, Union import fastapi import uvloop import zmq import zmq.asyncio from fastapi import BackgroundTasks - from sglang.srt.aio_rwlock import RWLock from sglang.srt.managers.generation_manager import GenerationManager from sglang.srt.managers.io_struct import ( - AbortReq, BatchEmbeddingOut, BatchStrOut, BatchTokenIDOut, @@ -57,7 +52,6 @@ ReleaseMemoryOccupationReqOutput, ResumeMemoryOccupationReqInput, ResumeMemoryOccupationReqOutput, - TokenizedEmbeddingReqInput, UpdateWeightFromDiskReqInput, UpdateWeightFromDiskReqOutput, UpdateWeightsFromDistributedReqInput, @@ -65,10 +59,8 @@ UpdateWeightsFromTensorReqInput, UpdateWeightsFromTensorReqOutput, ) -from sglang.srt.metrics.collector import TokenizerMetricsCollector from sglang.srt.server_args import PortArgs, ServerArgs from sglang.srt.utils import ( - dataclass_to_string_truncated, get_zmq_socket, kill_process_tree, ) @@ -430,9 +422,6 @@ async def handle_loop(self): recv_obj = await self.recv_from_detokenizer.recv_pyobj() self._result_dispatcher(recv_obj) - def collect_metrics(self, state: ReqState, recv_obj: BatchStrOut, i: int): - TODO_moved - def dump_requests(self, state: ReqState, out_dict: dict): self.dump_request_list.append( (state.obj, out_dict, state.created_time, time.time()) From 422ea33c16c29c73143ea898bea4a6af97423adb Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sat, 25 Jan 2025 09:41:18 +0800 Subject: [PATCH 026/136] extract _RequestDumper --- .../sglang/srt/managers/generation_manager.py | 34 ++++++++++++++++++- .../sglang/srt/managers/tokenizer_manager.py | 29 ---------------- 2 files changed, 33 insertions(+), 30 deletions(-) diff --git a/python/sglang/srt/managers/generation_manager.py b/python/sglang/srt/managers/generation_manager.py index f0186816c41..d20d8721008 100644 --- a/python/sglang/srt/managers/generation_manager.py +++ b/python/sglang/srt/managers/generation_manager.py @@ -3,9 +3,11 @@ import 
dataclasses import logging import os +import pickle import time +from datetime import datetime from http import HTTPStatus -from typing import Optional, List, Any, Union, Dict, Callable +from typing import Optional, List, Any, Union, Dict, Callable, Tuple import fastapi from sglang.srt.configs.model_config import ModelConfig @@ -580,6 +582,36 @@ def handle_batch_output_metrics( ) +class _RequestDumper: + def __init__(self): + self.dump_requests_folder = "" # By default do not dump + self.dump_requests_threshold = 1000 + self.dump_request_list: List[Tuple] = [] + + def dump_requests(self, state: '_ReqState', out_dict: dict): + self.dump_request_list.append( + (state.obj, out_dict, state.created_time, time.time()) + ) + + if len(self.dump_request_list) >= self.dump_requests_threshold: + filename = os.path.join( + self.dump_requests_folder, + datetime.now().strftime("%Y-%m-%d_%H-%M-%S") + ".pkl", + ) + logger.info(f"Dump {len(self.dump_request_list)} requests to {filename}") + + to_dump = self.dump_request_list + self.dump_request_list = [] + + def background_task(): + os.makedirs(self.dump_requests_folder, exist_ok=True) + with open(filename, "wb") as f: + pickle.dump(to_dump, f) + + # Schedule the task to run in the background without awaiting it + asyncio.create_task(asyncio.to_thread(background_task)) + + @dataclasses.dataclass class _ReqState: """Store the state a request.""" diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 44ba4b114d8..a25160e7142 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -16,13 +16,10 @@ import asyncio import logging import os -import pickle import signal import sys import threading -import time import uuid -from datetime import datetime from typing import Awaitable, Generic, List, Optional, Tuple, TypeVar, Union import fastapi @@ -105,9 +102,6 @@ def __init__( # Store states self.no_create_loop = False - self.dump_requests_folder = "" # By default do not dump - self.dump_requests_threshold = 1000 - self.dump_request_list: List[Tuple] = [] # The event to notify the weight sync is finished. 
self.model_update_lock = RWLock() @@ -422,29 +416,6 @@ async def handle_loop(self): recv_obj = await self.recv_from_detokenizer.recv_pyobj() self._result_dispatcher(recv_obj) - def dump_requests(self, state: ReqState, out_dict: dict): - self.dump_request_list.append( - (state.obj, out_dict, state.created_time, time.time()) - ) - - if len(self.dump_request_list) >= self.dump_requests_threshold: - filename = os.path.join( - self.dump_requests_folder, - datetime.now().strftime("%Y-%m-%d_%H-%M-%S") + ".pkl", - ) - logger.info(f"Dump {len(self.dump_request_list)} requests to {filename}") - - to_dump = self.dump_request_list - self.dump_request_list = [] - - def background_task(): - os.makedirs(self.dump_requests_folder, exist_ok=True) - with open(filename, "wb") as f: - pickle.dump(to_dump, f) - - # Schedule the task to run in the background without awaiting it - asyncio.create_task(asyncio.to_thread(background_task)) - def _handle_open_session_req_output(self, recv_obj): self.session_futures[recv_obj.session_id].set_result( recv_obj.session_id if recv_obj.success else None From 3e6e36360246359ed8e6f51906cf1cce7c91ab40 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sat, 25 Jan 2025 09:42:45 +0800 Subject: [PATCH 027/136] call setup --- python/sglang/srt/managers/generation_manager.py | 2 ++ python/sglang/srt/managers/tokenizer_manager.py | 4 ++-- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/managers/generation_manager.py b/python/sglang/srt/managers/generation_manager.py index d20d8721008..e7ebdcb761a 100644 --- a/python/sglang/srt/managers/generation_manager.py +++ b/python/sglang/srt/managers/generation_manager.py @@ -45,6 +45,8 @@ def __init__( else: self._metric_manager = None + self.request_dumper = _RequestDumper() + async def generate( self, obj: Union[GenerateReqInput, EmbeddingReqInput], diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index a25160e7142..84e7776e9da 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -347,9 +347,9 @@ def configure_logging(self, obj: ConfigureLoggingReq): if obj.log_requests_level is not None: self.log_requests_level = obj.log_requests_level if obj.dump_requests_folder is not None: - self.dump_requests_folder = obj.dump_requests_folder + self._generation_manager.request_dumper.dump_requests_folder = obj.dump_requests_folder if obj.dump_requests_threshold is not None: - self.dump_requests_threshold = obj.dump_requests_threshold + self._generation_manager.request_dumper.dump_requests_threshold = obj.dump_requests_threshold logging.info(f"Config logging: {obj=}") def create_abort_task(self, obj: GenerateReqInput): From 56dcbd1f7c30d0dffcd8552ab4fb469380ab43e8 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sat, 25 Jan 2025 09:44:22 +0800 Subject: [PATCH 028/136] call handle_batch_output --- python/sglang/srt/managers/generation_manager.py | 2 +- python/sglang/srt/managers/tokenizer_manager.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/managers/generation_manager.py b/python/sglang/srt/managers/generation_manager.py index e7ebdcb761a..fe008d36610 100644 --- a/python/sglang/srt/managers/generation_manager.py +++ b/python/sglang/srt/managers/generation_manager.py @@ -217,7 +217,7 @@ async def _handle_batch_request( except StopAsyncIteration: pass - def _handle_batch_output( + def handle_batch_output( self, recv_obj: Union[BatchStrOut, BatchEmbeddingOut, 
BatchTokenIDOut] ): for index, rid in enumerate(recv_obj.rids): diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 84e7776e9da..eeaf70f21cd 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -145,7 +145,7 @@ def __init__( [ ( (BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut), - self._handle_batch_output, + self._generation_manager.handle_batch_output, ), (OpenSessionReqOutput, self._handle_open_session_req_output), ( From 0c08f3092c23de13b60ec47cea9e8ec51e4ddbb7 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sat, 25 Jan 2025 09:44:56 +0800 Subject: [PATCH 029/136] more tokenizer_manager call generation_manager --- python/sglang/srt/managers/tokenizer_manager.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index eeaf70f21cd..1bf6ce3bbc3 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -397,7 +397,7 @@ async def sigterm_watchdog(self): # Drain requests while True: - remain_num_req = len(self.rid_to_state) + remain_num_req = len(self._generation_manager.rid_to_state) logger.info( f"Gracefully exiting... remaining number of requests {remain_num_req}" ) From deec6aff5437a0e53f52666970edd3769495d638 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sat, 25 Jan 2025 09:45:45 +0800 Subject: [PATCH 030/136] use property --- python/sglang/srt/managers/tokenizer_manager.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 1bf6ce3bbc3..f4785ce7610 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -96,10 +96,6 @@ def __init__( self.model_path = server_args.model_path self.served_model_name = server_args.served_model_name - self.is_generation = self.model_config.is_generation - self.context_len = self.model_config.context_len - self.image_token_id = self.model_config.image_token_id - # Store states self.no_create_loop = False @@ -430,6 +426,18 @@ def _handle_update_weights_from_disk_req_output(self, recv_obj): if len(self.model_update_tmp) == self.server_args.dp_size: self.model_update_result.set_result(self.model_update_tmp) + @property + def is_generation(self): + return self._generation_manager.model_config.is_generation + + @property + def tokenizer(self): + return self._generation_manager.tokenizer + + @property + def image_token_id(self): + return self._generation_manager.model_config.image_token_id + async def print_exception_wrapper(func): """ From 43dd4e240decbccdbd5a3330e82aa4edb7b1bb4e Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sat, 25 Jan 2025 09:46:53 +0800 Subject: [PATCH 031/136] call request_dumper --- python/sglang/srt/managers/generation_manager.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/python/sglang/srt/managers/generation_manager.py b/python/sglang/srt/managers/generation_manager.py index fe008d36610..3656e459ff8 100644 --- a/python/sglang/srt/managers/generation_manager.py +++ b/python/sglang/srt/managers/generation_manager.py @@ -242,8 +242,7 @@ def handle_batch_output( stream=state.obj.stream if hasattr(state.obj, "stream") else None, ) - if self.dump_requests_folder and state.finished and state.obj.log_metrics: - 
self.dump_requests(state, out_dict) + self.request_dumper.maybe_dump_requests(state=state, out_dict=out_dict) def abort_request(self, rid: str): if rid not in self.rid_to_state: @@ -590,7 +589,11 @@ def __init__(self): self.dump_requests_threshold = 1000 self.dump_request_list: List[Tuple] = [] - def dump_requests(self, state: '_ReqState', out_dict: dict): + def maybe_dump_requests(self, state: '_ReqState', out_dict: dict): + if self.dump_requests_folder and state.finished and state.obj.log_metrics: + self._dump_requests(state, out_dict) + + def _dump_requests(self, state: '_ReqState', out_dict: dict): self.dump_request_list.append( (state.obj, out_dict, state.created_time, time.time()) ) From 2d09b582f893ec8478d19b9bbee17f6b172c293f Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sat, 25 Jan 2025 09:47:22 +0800 Subject: [PATCH 032/136] call on_request --- python/sglang/srt/managers/generation_manager.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/managers/generation_manager.py b/python/sglang/srt/managers/generation_manager.py index 3656e459ff8..4fabcbc411f 100644 --- a/python/sglang/srt/managers/generation_manager.py +++ b/python/sglang/srt/managers/generation_manager.py @@ -89,7 +89,7 @@ def _send_one_request( event = asyncio.Event() state = _ReqState([], False, event, obj, metric=_MetricReqState(created_time=created_time)) self.rid_to_state[obj.rid] = state - self.send_to_scheduler.send_pyobj(tokenized_obj) + self.on_request(tokenized_obj) async def _wait_one_response( self, @@ -249,7 +249,7 @@ def abort_request(self, rid: str): return del self.rid_to_state[rid] req = AbortReq(rid) - self.send_to_scheduler.send_pyobj(req) + self.on_request(req) @property def tokenizer(self): From 5701e2093410f41a1987d2a8629ba85a34149715 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sat, 25 Jan 2025 09:48:18 +0800 Subject: [PATCH 033/136] fix minor field names --- python/sglang/srt/managers/generation_manager.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/python/sglang/srt/managers/generation_manager.py b/python/sglang/srt/managers/generation_manager.py index 4fabcbc411f..9e7c7128d78 100644 --- a/python/sglang/srt/managers/generation_manager.py +++ b/python/sglang/srt/managers/generation_manager.py @@ -70,7 +70,7 @@ async def generate( is_single = obj.is_single if is_single: - tokenized_obj = await self._tokenize_one_request(obj) + tokenized_obj = await self._generation_converter.tokenize_request(obj) self._send_one_request(obj, tokenized_obj, created_time) async for response in self._wait_one_response(obj, request): yield response @@ -153,7 +153,7 @@ async def _handle_batch_request( # Send all requests for i in range(batch_size): tmp_obj = obj[i] - tokenized_obj = await self._tokenize_one_request(tmp_obj) + tokenized_obj = await self._generation_converter.tokenize_request(tmp_obj) self._send_one_request(tmp_obj, tokenized_obj, created_time) generators.append(self._wait_one_response(tmp_obj, request)) rids.append(tmp_obj.rid) @@ -169,7 +169,7 @@ async def _handle_batch_request( # Tokenize all requests objs = [obj[i] for i in range(batch_size)] tokenized_objs = await asyncio.gather( - *(self._tokenize_one_request(obj) for obj in objs) + *(self._generation_converter.tokenize_request(obj) for obj in objs) ) # Cache the common prefix for parallel sampling @@ -595,7 +595,7 @@ def maybe_dump_requests(self, state: '_ReqState', out_dict: dict): def _dump_requests(self, state: '_ReqState', out_dict: dict): self.dump_request_list.append( - 
(state.obj, out_dict, state.created_time, time.time()) + (state.obj, out_dict, state.metric.created_time, time.time()) ) if len(self.dump_request_list) >= self.dump_requests_threshold: From cff89f0ae956cd64c441d2ba50c25a1de0a661e2 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sat, 25 Jan 2025 09:49:05 +0800 Subject: [PATCH 034/136] fix more field names --- python/sglang/srt/managers/generation_manager.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/managers/generation_manager.py b/python/sglang/srt/managers/generation_manager.py index 9e7c7128d78..3a23ac0a90d 100644 --- a/python/sglang/srt/managers/generation_manager.py +++ b/python/sglang/srt/managers/generation_manager.py @@ -54,7 +54,7 @@ async def generate( ): created_time = time.time() - if isinstance(obj, EmbeddingReqInput) and self.is_generation: + if isinstance(obj, EmbeddingReqInput) and self.model_config.is_generation: raise ValueError( "This model does not appear to be an embedding model by default. " "Please add `--is-embedding` when launching the server or try another model." @@ -323,7 +323,7 @@ async def tokenize_request( ) input_ids = self.tokenizer.encode(input_text) - if self.is_generation: + if self.model_config.is_generation: # TODO: also support getting embeddings for multimodal models image_inputs: Dict = await self.image_processor.process_images_async( obj.image_data, input_text or input_ids, obj, self.max_req_input_len From ba0f1b1777f21cb55bf4a63af13e1523cea57984 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sat, 25 Jan 2025 09:49:40 +0800 Subject: [PATCH 035/136] more --- python/sglang/srt/managers/generation_manager.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/managers/generation_manager.py b/python/sglang/srt/managers/generation_manager.py index 3a23ac0a90d..246b9f92595 100644 --- a/python/sglang/srt/managers/generation_manager.py +++ b/python/sglang/srt/managers/generation_manager.py @@ -62,7 +62,7 @@ async def generate( obj.normalize_batch_and_arguments() - if self.log_requests: + if self.server_args.log_requests: max_length = 2048 if self.log_requests_level == 0 else 1 << 30 logger.info( f"Receive: obj={dataclass_to_string_truncated(obj, max_length)}" @@ -112,7 +112,7 @@ async def _wait_one_response( state.out_list = [] if state.finished: - if self.log_requests: + if self.server_args.log_requests: max_length = 2048 if self.log_requests_level == 0 else 1 << 30 msg = f"Finish: obj={dataclass_to_string_truncated(obj, max_length)}, out={dataclass_to_string_truncated(out, max_length)}" logger.info(msg) From 4b03255d8b95dd77632b8b06cbc321f3e6d3c083 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sat, 25 Jan 2025 09:50:41 +0800 Subject: [PATCH 036/136] extract _RequestLogger --- python/sglang/srt/managers/generation_manager.py | 7 +++++++ python/sglang/srt/managers/tokenizer_manager.py | 6 ++---- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/python/sglang/srt/managers/generation_manager.py b/python/sglang/srt/managers/generation_manager.py index 246b9f92595..6a2ed637a13 100644 --- a/python/sglang/srt/managers/generation_manager.py +++ b/python/sglang/srt/managers/generation_manager.py @@ -45,6 +45,7 @@ def __init__( else: self._metric_manager = None + self.request_logger = _RequestLogger() self.request_dumper = _RequestDumper() async def generate( @@ -583,6 +584,12 @@ def handle_batch_output_metrics( ) +class _RequestLogger: + def __init__(self, server_args: ServerArgs): + self.log_requests = server_args.log_requests + 
self.log_requests_level = 0 + + class _RequestDumper: def __init__(self): self.dump_requests_folder = "" # By default do not dump diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index f4785ce7610..83ebaa25630 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -80,8 +80,6 @@ def __init__( self.server_args = server_args self.enable_metrics = server_args.enable_metrics - self.log_requests = server_args.log_requests - self.log_requests_level = 0 # Init inter-process communication context = zmq.asyncio.Context(2) @@ -339,9 +337,9 @@ async def close_session( def configure_logging(self, obj: ConfigureLoggingReq): if obj.log_requests is not None: - self.log_requests = obj.log_requests + self._generation_manager.request_logger.log_requests = obj.log_requests if obj.log_requests_level is not None: - self.log_requests_level = obj.log_requests_level + self._generation_manager.request_logger.log_requests_level = obj.log_requests_level if obj.dump_requests_folder is not None: self._generation_manager.request_dumper.dump_requests_folder = obj.dump_requests_folder if obj.dump_requests_threshold is not None: From 75dc737a92fd2e6cd9df7d978c4b5d2e6db957a0 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sat, 25 Jan 2025 09:52:42 +0800 Subject: [PATCH 037/136] extract logger body --- .../sglang/srt/managers/generation_manager.py | 24 ++++++++++++------- 1 file changed, 15 insertions(+), 9 deletions(-) diff --git a/python/sglang/srt/managers/generation_manager.py b/python/sglang/srt/managers/generation_manager.py index 6a2ed637a13..ea951e467a2 100644 --- a/python/sglang/srt/managers/generation_manager.py +++ b/python/sglang/srt/managers/generation_manager.py @@ -63,11 +63,7 @@ async def generate( obj.normalize_batch_and_arguments() - if self.server_args.log_requests: - max_length = 2048 if self.log_requests_level == 0 else 1 << 30 - logger.info( - f"Receive: obj={dataclass_to_string_truncated(obj, max_length)}" - ) + self.request_logger.log_generate(obj) is_single = obj.is_single if is_single: @@ -113,10 +109,7 @@ async def _wait_one_response( state.out_list = [] if state.finished: - if self.server_args.log_requests: - max_length = 2048 if self.log_requests_level == 0 else 1 << 30 - msg = f"Finish: obj={dataclass_to_string_truncated(obj, max_length)}, out={dataclass_to_string_truncated(out, max_length)}" - logger.info(msg) + self.request_logger.log_response(obj, out) del self.rid_to_state[obj.rid] # Check if this was an abort/error created by scheduler @@ -589,6 +582,19 @@ def __init__(self, server_args: ServerArgs): self.log_requests = server_args.log_requests self.log_requests_level = 0 + def log_generate(self, obj): + if self.log_requests: + max_length = 2048 if self.log_requests_level == 0 else 1 << 30 + logger.info( + f"Receive: obj={dataclass_to_string_truncated(obj, max_length)}" + ) + + def log_response(self, obj, out): + if self.log_requests: + max_length = 2048 if self.log_requests_level == 0 else 1 << 30 + msg = f"Finish: obj={dataclass_to_string_truncated(obj, max_length)}, out={dataclass_to_string_truncated(out, max_length)}" + logger.info(msg) + class _RequestDumper: def __init__(self): From 4100d6086d8e2c68a9c3803870ca20b2f13ad5ca Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sat, 25 Jan 2025 09:53:00 +0800 Subject: [PATCH 038/136] fix err --- python/sglang/srt/managers/generation_manager.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/python/sglang/srt/managers/generation_manager.py b/python/sglang/srt/managers/generation_manager.py index ea951e467a2..b8db3200f93 100644 --- a/python/sglang/srt/managers/generation_manager.py +++ b/python/sglang/srt/managers/generation_manager.py @@ -45,7 +45,7 @@ def __init__( else: self._metric_manager = None - self.request_logger = _RequestLogger() + self.request_logger = _RequestLogger(server_args) self.request_dumper = _RequestDumper() async def generate( From ba4ad8eb48a7c93e07544344d044fbf759886e22 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sat, 25 Jan 2025 09:56:12 +0800 Subject: [PATCH 039/136] fix field --- python/sglang/srt/managers/generation_manager.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/python/sglang/srt/managers/generation_manager.py b/python/sglang/srt/managers/generation_manager.py index b8db3200f93..8d35ac1b23c 100644 --- a/python/sglang/srt/managers/generation_manager.py +++ b/python/sglang/srt/managers/generation_manager.py @@ -332,20 +332,20 @@ async def tokenize_request( ) input_token_num = len(input_ids) if input_ids is not None else 0 - if input_token_num >= self.context_len: + if input_token_num >= self.model_config.context_len: raise ValueError( f"The input ({input_token_num} tokens) is longer than the " - f"model's context length ({self.context_len} tokens)." + f"model's context length ({self.model_config.context_len} tokens)." ) if ( obj.sampling_params.get("max_new_tokens") is not None and obj.sampling_params.get("max_new_tokens") + input_token_num - >= self.context_len + >= self.model_config.context_len ): raise ValueError( f"Requested token count exceeds the model's maximum context length " - f"of {self.context_len} tokens. You requested a total of " + f"of {self.model_config.context_len} tokens. 
You requested a total of " f"{obj.sampling_params.get('max_new_tokens') + input_token_num} " f"tokens: {input_token_num} tokens from the input messages and " f"{obj.sampling_params.get('max_new_tokens')} tokens for the " From 5450fa3e3a7d07b6c4bdd6da92ac2cbf5e0c9030 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sat, 25 Jan 2025 09:59:29 +0800 Subject: [PATCH 040/136] empty package --- python/sglang/srt/orchestration/__init__.py | 0 python/sglang/srt/orchestration/std/__init__.py | 0 2 files changed, 0 insertions(+), 0 deletions(-) create mode 100644 python/sglang/srt/orchestration/__init__.py create mode 100644 python/sglang/srt/orchestration/std/__init__.py diff --git a/python/sglang/srt/orchestration/__init__.py b/python/sglang/srt/orchestration/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/python/sglang/srt/orchestration/std/__init__.py b/python/sglang/srt/orchestration/std/__init__.py new file mode 100644 index 00000000000..e69de29bb2d From 9080d45650ae461ce28866bb8c37cb5d6c02279d Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sat, 25 Jan 2025 10:00:14 +0800 Subject: [PATCH 041/136] fmt --- .../sglang/srt/managers/generation_manager.py | 33 ++++++++++++++----- .../sglang/srt/managers/tokenizer_manager.py | 18 ++++++---- 2 files changed, 36 insertions(+), 15 deletions(-) diff --git a/python/sglang/srt/managers/generation_manager.py b/python/sglang/srt/managers/generation_manager.py index 8d35ac1b23c..b32e66c6046 100644 --- a/python/sglang/srt/managers/generation_manager.py +++ b/python/sglang/srt/managers/generation_manager.py @@ -7,14 +7,27 @@ import time from datetime import datetime from http import HTTPStatus -from typing import Optional, List, Any, Union, Dict, Callable, Tuple +from typing import Any, Callable, Dict, List, Optional, Tuple, Union import fastapi + from sglang.srt.configs.model_config import ModelConfig from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer -from sglang.srt.managers.image_processor import get_dummy_image_processor, get_image_processor -from sglang.srt.managers.io_struct import GenerateReqInput, EmbeddingReqInput, SessionParams, TokenizedGenerateReqInput, \ - TokenizedEmbeddingReqInput, BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut, AbortReq +from sglang.srt.managers.image_processor import ( + get_dummy_image_processor, + get_image_processor, +) +from sglang.srt.managers.io_struct import ( + AbortReq, + BatchEmbeddingOut, + BatchStrOut, + BatchTokenIDOut, + EmbeddingReqInput, + GenerateReqInput, + SessionParams, + TokenizedEmbeddingReqInput, + TokenizedGenerateReqInput, +) from sglang.srt.metrics.collector import TokenizerMetricsCollector from sglang.srt.sampling.sampling_params import SamplingParams from sglang.srt.server_args import ServerArgs @@ -84,7 +97,9 @@ def _send_one_request( created_time: Optional[float] = None, ): event = asyncio.Event() - state = _ReqState([], False, event, obj, metric=_MetricReqState(created_time=created_time)) + state = _ReqState( + [], False, event, obj, metric=_MetricReqState(created_time=created_time) + ) self.rid_to_state[obj.rid] = state self.on_request(tokenized_obj) @@ -147,7 +162,9 @@ async def _handle_batch_request( # Send all requests for i in range(batch_size): tmp_obj = obj[i] - tokenized_obj = await self._generation_converter.tokenize_request(tmp_obj) + tokenized_obj = await self._generation_converter.tokenize_request( + tmp_obj + ) self._send_one_request(tmp_obj, tokenized_obj, created_time) generators.append(self._wait_one_response(tmp_obj, request)) 
rids.append(tmp_obj.rid) @@ -602,11 +619,11 @@ def __init__(self): self.dump_requests_threshold = 1000 self.dump_request_list: List[Tuple] = [] - def maybe_dump_requests(self, state: '_ReqState', out_dict: dict): + def maybe_dump_requests(self, state: "_ReqState", out_dict: dict): if self.dump_requests_folder and state.finished and state.obj.log_metrics: self._dump_requests(state, out_dict) - def _dump_requests(self, state: '_ReqState', out_dict: dict): + def _dump_requests(self, state: "_ReqState", out_dict: dict): self.dump_request_list.append( (state.obj, out_dict, state.metric.created_time, time.time()) ) diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 83ebaa25630..03527c60280 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -27,6 +27,7 @@ import zmq import zmq.asyncio from fastapi import BackgroundTasks + from sglang.srt.aio_rwlock import RWLock from sglang.srt.managers.generation_manager import GenerationManager from sglang.srt.managers.io_struct import ( @@ -57,10 +58,7 @@ UpdateWeightsFromTensorReqOutput, ) from sglang.srt.server_args import PortArgs, ServerArgs -from sglang.srt.utils import ( - get_zmq_socket, - kill_process_tree, -) +from sglang.srt.utils import get_zmq_socket, kill_process_tree from sglang.utils import TypeBasedDispatcher, get_exception_traceback asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) @@ -339,11 +337,17 @@ def configure_logging(self, obj: ConfigureLoggingReq): if obj.log_requests is not None: self._generation_manager.request_logger.log_requests = obj.log_requests if obj.log_requests_level is not None: - self._generation_manager.request_logger.log_requests_level = obj.log_requests_level + self._generation_manager.request_logger.log_requests_level = ( + obj.log_requests_level + ) if obj.dump_requests_folder is not None: - self._generation_manager.request_dumper.dump_requests_folder = obj.dump_requests_folder + self._generation_manager.request_dumper.dump_requests_folder = ( + obj.dump_requests_folder + ) if obj.dump_requests_threshold is not None: - self._generation_manager.request_dumper.dump_requests_threshold = obj.dump_requests_threshold + self._generation_manager.request_dumper.dump_requests_threshold = ( + obj.dump_requests_threshold + ) logging.info(f"Config logging: {obj=}") def create_abort_task(self, obj: GenerateReqInput): From b28ca30faca29941780c8756d48b1f195e80ad31 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sat, 25 Jan 2025 10:01:11 +0800 Subject: [PATCH 042/136] empty file --- python/sglang/srt/orchestration/std/orchestrator.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 python/sglang/srt/orchestration/std/orchestrator.py diff --git a/python/sglang/srt/orchestration/std/orchestrator.py b/python/sglang/srt/orchestration/std/orchestrator.py new file mode 100644 index 00000000000..e69de29bb2d From e4f2393b3234ea82670b6c9d0feff26807df1880 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sat, 25 Jan 2025 10:01:42 +0800 Subject: [PATCH 043/136] mv file --- .../sglang/srt/managers/tokenizer_manager.py | 491 ------------------ .../srt/orchestration/std/orchestrator.py | 491 ++++++++++++++++++ 2 files changed, 491 insertions(+), 491 deletions(-) delete mode 100644 python/sglang/srt/managers/tokenizer_manager.py diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py deleted file mode 100644 index 03527c60280..00000000000 --- 
a/python/sglang/srt/managers/tokenizer_manager.py +++ /dev/null @@ -1,491 +0,0 @@ -# Copyright 2023-2024 SGLang Team -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""TokenizerManager is a process that tokenizes the text.""" - -import asyncio -import logging -import os -import signal -import sys -import threading -import uuid -from typing import Awaitable, Generic, List, Optional, Tuple, TypeVar, Union - -import fastapi -import uvloop -import zmq -import zmq.asyncio -from fastapi import BackgroundTasks - -from sglang.srt.aio_rwlock import RWLock -from sglang.srt.managers.generation_manager import GenerationManager -from sglang.srt.managers.io_struct import ( - BatchEmbeddingOut, - BatchStrOut, - BatchTokenIDOut, - CloseSessionReqInput, - ConfigureLoggingReq, - EmbeddingReqInput, - FlushCacheReq, - GenerateReqInput, - GetWeightsByNameReqInput, - GetWeightsByNameReqOutput, - InitWeightsUpdateGroupReqInput, - InitWeightsUpdateGroupReqOutput, - OpenSessionReqInput, - OpenSessionReqOutput, - ProfileReq, - ReleaseMemoryOccupationReqInput, - ReleaseMemoryOccupationReqOutput, - ResumeMemoryOccupationReqInput, - ResumeMemoryOccupationReqOutput, - UpdateWeightFromDiskReqInput, - UpdateWeightFromDiskReqOutput, - UpdateWeightsFromDistributedReqInput, - UpdateWeightsFromDistributedReqOutput, - UpdateWeightsFromTensorReqInput, - UpdateWeightsFromTensorReqOutput, -) -from sglang.srt.server_args import PortArgs, ServerArgs -from sglang.srt.utils import get_zmq_socket, kill_process_tree -from sglang.utils import TypeBasedDispatcher, get_exception_traceback - -asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) - -logger = logging.getLogger(__name__) - - -class TokenizerManager: - """TokenizerManager is a process that tokenizes the text.""" - - def __init__( - self, - server_args: ServerArgs, - port_args: PortArgs, - ): - # Parse args - - self.server_args = server_args - self.enable_metrics = server_args.enable_metrics - - # Init inter-process communication - context = zmq.asyncio.Context(2) - self.recv_from_detokenizer = get_zmq_socket( - context, zmq.PULL, port_args.tokenizer_ipc_name, True - ) - self.send_to_scheduler = get_zmq_socket( - context, zmq.PUSH, port_args.scheduler_input_ipc_name, True - ) - - # Read model args - self.model_path = server_args.model_path - self.served_model_name = server_args.served_model_name - - # Store states - self.no_create_loop = False - - # The event to notify the weight sync is finished. 
- self.model_update_lock = RWLock() - self.model_update_result: Optional[Awaitable[UpdateWeightFromDiskReqOutput]] = ( - None - ) - self.asyncio_tasks = set() - - # For session info - self.session_futures = {} # session_id -> asyncio event - - self._generation_manager = GenerationManager( - server_args=server_args, - on_request=self.send_to_scheduler.send_pyobj, - ) - - # Others - self.gracefully_exit = False - self.init_weights_update_group_communicator = _Communicator( - self.send_to_scheduler, server_args.dp_size - ) - self.update_weights_from_distributed_communicator = _Communicator( - self.send_to_scheduler, server_args.dp_size - ) - self.update_weights_from_tensor_communicator = _Communicator( - self.send_to_scheduler, server_args.dp_size - ) - self.get_weights_by_name_communicator = _Communicator( - self.send_to_scheduler, server_args.dp_size - ) - self.release_memory_occupation_communicator = _Communicator( - self.send_to_scheduler, server_args.dp_size - ) - self.resume_memory_occupation_communicator = _Communicator( - self.send_to_scheduler, server_args.dp_size - ) - # Set after scheduler is initialized - self.max_req_input_len = None - - self._result_dispatcher = TypeBasedDispatcher( - [ - ( - (BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut), - self._generation_manager.handle_batch_output, - ), - (OpenSessionReqOutput, self._handle_open_session_req_output), - ( - UpdateWeightFromDiskReqOutput, - self._handle_update_weights_from_disk_req_output, - ), - ( - InitWeightsUpdateGroupReqOutput, - self.init_weights_update_group_communicator.handle_recv, - ), - ( - UpdateWeightsFromDistributedReqOutput, - self.update_weights_from_distributed_communicator.handle_recv, - ), - ( - UpdateWeightsFromTensorReqOutput, - self.update_weights_from_tensor_communicator.handle_recv, - ), - ( - GetWeightsByNameReqOutput, - self.get_weights_by_name_communicator.handle_recv, - ), - ( - ReleaseMemoryOccupationReqOutput, - self.release_memory_occupation_communicator.handle_recv, - ), - ( - ResumeMemoryOccupationReqOutput, - self.resume_memory_occupation_communicator.handle_recv, - ), - ] - ) - - async def generate_request( - self, - obj: Union[GenerateReqInput, EmbeddingReqInput], - request: Optional[fastapi.Request] = None, - ): - self.auto_create_handle_loop() - async with self.model_update_lock.reader_lock: - async for value in self._generation_manager.generate(obj, request): - yield value - - def flush_cache(self): - req = FlushCacheReq() - self.send_to_scheduler.send_pyobj(req) - - def abort_request(self, rid: str): - self._generation_manager.abort_request(rid) - - def start_profile(self): - req = ProfileReq.START_PROFILE - self.send_to_scheduler.send_pyobj(req) - - def stop_profile(self): - req = ProfileReq.STOP_PROFILE - self.send_to_scheduler.send_pyobj(req) - - async def update_weights_from_disk( - self, - obj: UpdateWeightFromDiskReqInput, - request: Optional[fastapi.Request] = None, - ) -> Tuple[bool, str]: - self.auto_create_handle_loop() - - # default the load format to the server_args - if obj.load_format is None: - obj.load_format = self.server_args.load_format - logger.info("Start update_weights. Load format=%s", obj.load_format) - - if True: - # Hold the lock if it is not async. This means that weight sync - # cannot run while requests are in progress. 
- async with self.model_update_lock.writer_lock: - return await self._wait_for_model_update_from_disk(obj) - - async def _wait_for_model_update_from_disk( - self, obj: UpdateWeightFromDiskReqInput - ) -> Tuple[bool, str]: - self.send_to_scheduler.send_pyobj(obj) - self.model_update_result = asyncio.Future() - if self.server_args.dp_size == 1: - result = await self.model_update_result - if result.success: - self.served_model_name = obj.model_path - self.server_args.model_path = obj.model_path - self.server_args.load_format = obj.load_format - self.model_path = obj.model_path - return result.success, result.message - else: # self.server_args.dp_size > 1 - self.model_update_tmp = [] - result = await self.model_update_result - - all_success = all([r.success for r in result]) - if all_success is True: - self.server_args.model_path = obj.model_path - self.server_args.load_format = obj.load_format - self.model_path = obj.model_path - all_message = [r.message for r in result] - all_message = " | ".join(all_message) - return all_success, all_message - - async def init_weights_update_group( - self, - obj: InitWeightsUpdateGroupReqInput, - request: Optional[fastapi.Request] = None, - ) -> Tuple[bool, str]: - self.auto_create_handle_loop() - assert ( - self.server_args.dp_size == 1 - ), "dp_size must be 1 for init parameter update group" - result = (await self.init_weights_update_group_communicator(obj))[0] - return result.success, result.message - - async def update_weights_from_distributed( - self, - obj: UpdateWeightsFromDistributedReqInput, - request: Optional[fastapi.Request] = None, - ) -> Tuple[bool, str]: - self.auto_create_handle_loop() - assert ( - self.server_args.dp_size == 1 - ), "dp_size must be for update weights from distributed" - - # This means that weight sync - # cannot run while requests are in progress. - async with self.model_update_lock.writer_lock: - result = (await self.update_weights_from_distributed_communicator(obj))[0] - return result.success, result.message - - async def update_weights_from_tensor( - self, - obj: UpdateWeightsFromTensorReqInput, - request: Optional[fastapi.Request] = None, - ) -> Tuple[bool, str]: - self.auto_create_handle_loop() - assert ( - self.server_args.dp_size == 1 - ), "dp_size must be for update weights from distributed" - - # This means that weight sync - # cannot run while requests are in progress. 
- async with self.model_update_lock.writer_lock: - result = (await self.update_weights_from_tensor_communicator(obj))[0] - return result.success, result.message - - async def get_weights_by_name( - self, obj: GetWeightsByNameReqInput, request: Optional[fastapi.Request] = None - ): - self.auto_create_handle_loop() - results = await self.get_weights_by_name_communicator(obj) - all_parameters = [r.parameter for r in results] - if self.server_args.dp_size == 1: - return all_parameters[0] - else: - return all_parameters - - async def release_memory_occupation( - self, - obj: ReleaseMemoryOccupationReqInput, - request: Optional[fastapi.Request] = None, - ): - self.auto_create_handle_loop() - await self.release_memory_occupation_communicator(obj) - - async def resume_memory_occupation( - self, - obj: ResumeMemoryOccupationReqInput, - request: Optional[fastapi.Request] = None, - ): - self.auto_create_handle_loop() - await self.resume_memory_occupation_communicator(obj) - - async def open_session( - self, obj: OpenSessionReqInput, request: Optional[fastapi.Request] = None - ): - self.auto_create_handle_loop() - - if obj.session_id is None: - obj.session_id = uuid.uuid4().hex - elif obj.session_id in self.session_futures: - return None - - self.send_to_scheduler.send_pyobj(obj) - - self.session_futures[obj.session_id] = asyncio.Future() - session_id = await self.session_futures[obj.session_id] - del self.session_futures[obj.session_id] - return session_id - - async def close_session( - self, obj: CloseSessionReqInput, request: Optional[fastapi.Request] = None - ): - await self.send_to_scheduler.send_pyobj(obj) - - def configure_logging(self, obj: ConfigureLoggingReq): - if obj.log_requests is not None: - self._generation_manager.request_logger.log_requests = obj.log_requests - if obj.log_requests_level is not None: - self._generation_manager.request_logger.log_requests_level = ( - obj.log_requests_level - ) - if obj.dump_requests_folder is not None: - self._generation_manager.request_dumper.dump_requests_folder = ( - obj.dump_requests_folder - ) - if obj.dump_requests_threshold is not None: - self._generation_manager.request_dumper.dump_requests_threshold = ( - obj.dump_requests_threshold - ) - logging.info(f"Config logging: {obj=}") - - def create_abort_task(self, obj: GenerateReqInput): - # Abort the request if the client is disconnected. - async def abort_request(): - await asyncio.sleep(1) - if obj.is_single: - self.abort_request(obj.rid) - else: - for rid in obj.rid: - self.abort_request(rid) - - background_tasks = BackgroundTasks() - background_tasks.add_task(abort_request) - return background_tasks - - def auto_create_handle_loop(self): - if self.no_create_loop: - return - - self.no_create_loop = True - loop = asyncio.get_event_loop() - self.asyncio_tasks.add( - loop.create_task(print_exception_wrapper(self.handle_loop)) - ) - - # We cannot add signal handler when the tokenizer manager is not in - # the main thread due to the CPython limitation. - if threading.current_thread() is threading.main_thread(): - signal_handler = SignalHandler(self) - loop.add_signal_handler(signal.SIGTERM, signal_handler.signal_handler) - else: - logger.warning( - "Signal handler is not added because the tokenizer manager is " - "not in the main thread. This disables graceful shutdown of the " - "tokenizer manager when SIGTERM is received." 
- ) - self.asyncio_tasks.add( - loop.create_task(print_exception_wrapper(self.sigterm_watchdog)) - ) - - async def sigterm_watchdog(self): - while not self.gracefully_exit: - await asyncio.sleep(5) - - # Drain requests - while True: - remain_num_req = len(self._generation_manager.rid_to_state) - logger.info( - f"Gracefully exiting... remaining number of requests {remain_num_req}" - ) - if remain_num_req > 0: - await asyncio.sleep(5) - else: - break - - kill_process_tree(os.getpid(), include_parent=True) - sys.exit(0) - - async def handle_loop(self): - """The event loop that handles requests""" - - while True: - recv_obj = await self.recv_from_detokenizer.recv_pyobj() - self._result_dispatcher(recv_obj) - - def _handle_open_session_req_output(self, recv_obj): - self.session_futures[recv_obj.session_id].set_result( - recv_obj.session_id if recv_obj.success else None - ) - - def _handle_update_weights_from_disk_req_output(self, recv_obj): - if self.server_args.dp_size == 1: - self.model_update_result.set_result(recv_obj) - else: # self.server_args.dp_size > 1 - self.model_update_tmp.append(recv_obj) - # set future if the all results are recevied - if len(self.model_update_tmp) == self.server_args.dp_size: - self.model_update_result.set_result(self.model_update_tmp) - - @property - def is_generation(self): - return self._generation_manager.model_config.is_generation - - @property - def tokenizer(self): - return self._generation_manager.tokenizer - - @property - def image_token_id(self): - return self._generation_manager.model_config.image_token_id - - -async def print_exception_wrapper(func): - """ - Sometimes an asyncio function does not print exception. - We do another wrapper to handle the exception. - """ - try: - await func() - except Exception: - traceback = get_exception_traceback() - logger.error(f"TokenizerManager hit an exception: {traceback}") - kill_process_tree(os.getpid(), include_parent=True) - sys.exit(1) - - -class SignalHandler: - def __init__(self, tokenizer_manager): - self.tokenizer_manager = tokenizer_manager - - def signal_handler(self, signum=None, frame=None): - logger.warning( - f"SIGTERM received. {signum=} {frame=}. Draining requests and shutting down..." - ) - self.tokenizer_manager.gracefully_exit = True - - -T = TypeVar("T") - - -class _Communicator(Generic[T]): - def __init__(self, sender, fan_out: int): - self._sender = sender - self._fan_out = fan_out - self._result_future: Optional[asyncio.Future] = None - self._result_values: Optional[List[T]] = None - - async def __call__(self, obj): - self._sender.send_pyobj(obj) - self._result_future = asyncio.Future() - self._result_values = [] - await self._result_future - result_values = self._result_values - self._result_future = self._result_values = None - return result_values - - def handle_recv(self, recv_obj: T): - self._result_values.append(recv_obj) - if len(self._result_values) == self._fan_out: - self._result_future.set_result(None) diff --git a/python/sglang/srt/orchestration/std/orchestrator.py b/python/sglang/srt/orchestration/std/orchestrator.py index e69de29bb2d..03527c60280 100644 --- a/python/sglang/srt/orchestration/std/orchestrator.py +++ b/python/sglang/srt/orchestration/std/orchestrator.py @@ -0,0 +1,491 @@ +# Copyright 2023-2024 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""TokenizerManager is a process that tokenizes the text.""" + +import asyncio +import logging +import os +import signal +import sys +import threading +import uuid +from typing import Awaitable, Generic, List, Optional, Tuple, TypeVar, Union + +import fastapi +import uvloop +import zmq +import zmq.asyncio +from fastapi import BackgroundTasks + +from sglang.srt.aio_rwlock import RWLock +from sglang.srt.managers.generation_manager import GenerationManager +from sglang.srt.managers.io_struct import ( + BatchEmbeddingOut, + BatchStrOut, + BatchTokenIDOut, + CloseSessionReqInput, + ConfigureLoggingReq, + EmbeddingReqInput, + FlushCacheReq, + GenerateReqInput, + GetWeightsByNameReqInput, + GetWeightsByNameReqOutput, + InitWeightsUpdateGroupReqInput, + InitWeightsUpdateGroupReqOutput, + OpenSessionReqInput, + OpenSessionReqOutput, + ProfileReq, + ReleaseMemoryOccupationReqInput, + ReleaseMemoryOccupationReqOutput, + ResumeMemoryOccupationReqInput, + ResumeMemoryOccupationReqOutput, + UpdateWeightFromDiskReqInput, + UpdateWeightFromDiskReqOutput, + UpdateWeightsFromDistributedReqInput, + UpdateWeightsFromDistributedReqOutput, + UpdateWeightsFromTensorReqInput, + UpdateWeightsFromTensorReqOutput, +) +from sglang.srt.server_args import PortArgs, ServerArgs +from sglang.srt.utils import get_zmq_socket, kill_process_tree +from sglang.utils import TypeBasedDispatcher, get_exception_traceback + +asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) + +logger = logging.getLogger(__name__) + + +class TokenizerManager: + """TokenizerManager is a process that tokenizes the text.""" + + def __init__( + self, + server_args: ServerArgs, + port_args: PortArgs, + ): + # Parse args + + self.server_args = server_args + self.enable_metrics = server_args.enable_metrics + + # Init inter-process communication + context = zmq.asyncio.Context(2) + self.recv_from_detokenizer = get_zmq_socket( + context, zmq.PULL, port_args.tokenizer_ipc_name, True + ) + self.send_to_scheduler = get_zmq_socket( + context, zmq.PUSH, port_args.scheduler_input_ipc_name, True + ) + + # Read model args + self.model_path = server_args.model_path + self.served_model_name = server_args.served_model_name + + # Store states + self.no_create_loop = False + + # The event to notify the weight sync is finished. 
+ self.model_update_lock = RWLock() + self.model_update_result: Optional[Awaitable[UpdateWeightFromDiskReqOutput]] = ( + None + ) + self.asyncio_tasks = set() + + # For session info + self.session_futures = {} # session_id -> asyncio event + + self._generation_manager = GenerationManager( + server_args=server_args, + on_request=self.send_to_scheduler.send_pyobj, + ) + + # Others + self.gracefully_exit = False + self.init_weights_update_group_communicator = _Communicator( + self.send_to_scheduler, server_args.dp_size + ) + self.update_weights_from_distributed_communicator = _Communicator( + self.send_to_scheduler, server_args.dp_size + ) + self.update_weights_from_tensor_communicator = _Communicator( + self.send_to_scheduler, server_args.dp_size + ) + self.get_weights_by_name_communicator = _Communicator( + self.send_to_scheduler, server_args.dp_size + ) + self.release_memory_occupation_communicator = _Communicator( + self.send_to_scheduler, server_args.dp_size + ) + self.resume_memory_occupation_communicator = _Communicator( + self.send_to_scheduler, server_args.dp_size + ) + # Set after scheduler is initialized + self.max_req_input_len = None + + self._result_dispatcher = TypeBasedDispatcher( + [ + ( + (BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut), + self._generation_manager.handle_batch_output, + ), + (OpenSessionReqOutput, self._handle_open_session_req_output), + ( + UpdateWeightFromDiskReqOutput, + self._handle_update_weights_from_disk_req_output, + ), + ( + InitWeightsUpdateGroupReqOutput, + self.init_weights_update_group_communicator.handle_recv, + ), + ( + UpdateWeightsFromDistributedReqOutput, + self.update_weights_from_distributed_communicator.handle_recv, + ), + ( + UpdateWeightsFromTensorReqOutput, + self.update_weights_from_tensor_communicator.handle_recv, + ), + ( + GetWeightsByNameReqOutput, + self.get_weights_by_name_communicator.handle_recv, + ), + ( + ReleaseMemoryOccupationReqOutput, + self.release_memory_occupation_communicator.handle_recv, + ), + ( + ResumeMemoryOccupationReqOutput, + self.resume_memory_occupation_communicator.handle_recv, + ), + ] + ) + + async def generate_request( + self, + obj: Union[GenerateReqInput, EmbeddingReqInput], + request: Optional[fastapi.Request] = None, + ): + self.auto_create_handle_loop() + async with self.model_update_lock.reader_lock: + async for value in self._generation_manager.generate(obj, request): + yield value + + def flush_cache(self): + req = FlushCacheReq() + self.send_to_scheduler.send_pyobj(req) + + def abort_request(self, rid: str): + self._generation_manager.abort_request(rid) + + def start_profile(self): + req = ProfileReq.START_PROFILE + self.send_to_scheduler.send_pyobj(req) + + def stop_profile(self): + req = ProfileReq.STOP_PROFILE + self.send_to_scheduler.send_pyobj(req) + + async def update_weights_from_disk( + self, + obj: UpdateWeightFromDiskReqInput, + request: Optional[fastapi.Request] = None, + ) -> Tuple[bool, str]: + self.auto_create_handle_loop() + + # default the load format to the server_args + if obj.load_format is None: + obj.load_format = self.server_args.load_format + logger.info("Start update_weights. Load format=%s", obj.load_format) + + if True: + # Hold the lock if it is not async. This means that weight sync + # cannot run while requests are in progress. 
+            async with self.model_update_lock.writer_lock:
+                return await self._wait_for_model_update_from_disk(obj)
+
+    async def _wait_for_model_update_from_disk(
+        self, obj: UpdateWeightFromDiskReqInput
+    ) -> Tuple[bool, str]:
+        self.send_to_scheduler.send_pyobj(obj)
+        self.model_update_result = asyncio.Future()
+        if self.server_args.dp_size == 1:
+            result = await self.model_update_result
+            if result.success:
+                self.served_model_name = obj.model_path
+                self.server_args.model_path = obj.model_path
+                self.server_args.load_format = obj.load_format
+                self.model_path = obj.model_path
+            return result.success, result.message
+        else:  # self.server_args.dp_size > 1
+            self.model_update_tmp = []
+            result = await self.model_update_result
+
+            all_success = all([r.success for r in result])
+            if all_success is True:
+                self.server_args.model_path = obj.model_path
+                self.server_args.load_format = obj.load_format
+                self.model_path = obj.model_path
+            all_message = [r.message for r in result]
+            all_message = " | ".join(all_message)
+            return all_success, all_message
+
+    async def init_weights_update_group(
+        self,
+        obj: InitWeightsUpdateGroupReqInput,
+        request: Optional[fastapi.Request] = None,
+    ) -> Tuple[bool, str]:
+        self.auto_create_handle_loop()
+        assert (
+            self.server_args.dp_size == 1
+        ), "dp_size must be 1 for init parameter update group"
+        result = (await self.init_weights_update_group_communicator(obj))[0]
+        return result.success, result.message
+
+    async def update_weights_from_distributed(
+        self,
+        obj: UpdateWeightsFromDistributedReqInput,
+        request: Optional[fastapi.Request] = None,
+    ) -> Tuple[bool, str]:
+        self.auto_create_handle_loop()
+        assert (
+            self.server_args.dp_size == 1
+        ), "dp_size must be 1 for update weights from distributed"
+
+        # This means that weight sync
+        # cannot run while requests are in progress.
+        async with self.model_update_lock.writer_lock:
+            result = (await self.update_weights_from_distributed_communicator(obj))[0]
+            return result.success, result.message
+
+    async def update_weights_from_tensor(
+        self,
+        obj: UpdateWeightsFromTensorReqInput,
+        request: Optional[fastapi.Request] = None,
+    ) -> Tuple[bool, str]:
+        self.auto_create_handle_loop()
+        assert (
+            self.server_args.dp_size == 1
+        ), "dp_size must be 1 for update weights from tensor"
+
+        # This means that weight sync
+        # cannot run while requests are in progress. 
+ async with self.model_update_lock.writer_lock: + result = (await self.update_weights_from_tensor_communicator(obj))[0] + return result.success, result.message + + async def get_weights_by_name( + self, obj: GetWeightsByNameReqInput, request: Optional[fastapi.Request] = None + ): + self.auto_create_handle_loop() + results = await self.get_weights_by_name_communicator(obj) + all_parameters = [r.parameter for r in results] + if self.server_args.dp_size == 1: + return all_parameters[0] + else: + return all_parameters + + async def release_memory_occupation( + self, + obj: ReleaseMemoryOccupationReqInput, + request: Optional[fastapi.Request] = None, + ): + self.auto_create_handle_loop() + await self.release_memory_occupation_communicator(obj) + + async def resume_memory_occupation( + self, + obj: ResumeMemoryOccupationReqInput, + request: Optional[fastapi.Request] = None, + ): + self.auto_create_handle_loop() + await self.resume_memory_occupation_communicator(obj) + + async def open_session( + self, obj: OpenSessionReqInput, request: Optional[fastapi.Request] = None + ): + self.auto_create_handle_loop() + + if obj.session_id is None: + obj.session_id = uuid.uuid4().hex + elif obj.session_id in self.session_futures: + return None + + self.send_to_scheduler.send_pyobj(obj) + + self.session_futures[obj.session_id] = asyncio.Future() + session_id = await self.session_futures[obj.session_id] + del self.session_futures[obj.session_id] + return session_id + + async def close_session( + self, obj: CloseSessionReqInput, request: Optional[fastapi.Request] = None + ): + await self.send_to_scheduler.send_pyobj(obj) + + def configure_logging(self, obj: ConfigureLoggingReq): + if obj.log_requests is not None: + self._generation_manager.request_logger.log_requests = obj.log_requests + if obj.log_requests_level is not None: + self._generation_manager.request_logger.log_requests_level = ( + obj.log_requests_level + ) + if obj.dump_requests_folder is not None: + self._generation_manager.request_dumper.dump_requests_folder = ( + obj.dump_requests_folder + ) + if obj.dump_requests_threshold is not None: + self._generation_manager.request_dumper.dump_requests_threshold = ( + obj.dump_requests_threshold + ) + logging.info(f"Config logging: {obj=}") + + def create_abort_task(self, obj: GenerateReqInput): + # Abort the request if the client is disconnected. + async def abort_request(): + await asyncio.sleep(1) + if obj.is_single: + self.abort_request(obj.rid) + else: + for rid in obj.rid: + self.abort_request(rid) + + background_tasks = BackgroundTasks() + background_tasks.add_task(abort_request) + return background_tasks + + def auto_create_handle_loop(self): + if self.no_create_loop: + return + + self.no_create_loop = True + loop = asyncio.get_event_loop() + self.asyncio_tasks.add( + loop.create_task(print_exception_wrapper(self.handle_loop)) + ) + + # We cannot add signal handler when the tokenizer manager is not in + # the main thread due to the CPython limitation. + if threading.current_thread() is threading.main_thread(): + signal_handler = SignalHandler(self) + loop.add_signal_handler(signal.SIGTERM, signal_handler.signal_handler) + else: + logger.warning( + "Signal handler is not added because the tokenizer manager is " + "not in the main thread. This disables graceful shutdown of the " + "tokenizer manager when SIGTERM is received." 
+            )
+        self.asyncio_tasks.add(
+            loop.create_task(print_exception_wrapper(self.sigterm_watchdog))
+        )
+
+    async def sigterm_watchdog(self):
+        while not self.gracefully_exit:
+            await asyncio.sleep(5)
+
+        # Drain requests
+        while True:
+            remain_num_req = len(self._generation_manager.rid_to_state)
+            logger.info(
+                f"Gracefully exiting... remaining number of requests {remain_num_req}"
+            )
+            if remain_num_req > 0:
+                await asyncio.sleep(5)
+            else:
+                break
+
+        kill_process_tree(os.getpid(), include_parent=True)
+        sys.exit(0)
+
+    async def handle_loop(self):
+        """The event loop that handles requests"""
+
+        while True:
+            recv_obj = await self.recv_from_detokenizer.recv_pyobj()
+            self._result_dispatcher(recv_obj)
+
+    def _handle_open_session_req_output(self, recv_obj):
+        self.session_futures[recv_obj.session_id].set_result(
+            recv_obj.session_id if recv_obj.success else None
+        )
+
+    def _handle_update_weights_from_disk_req_output(self, recv_obj):
+        if self.server_args.dp_size == 1:
+            self.model_update_result.set_result(recv_obj)
+        else:  # self.server_args.dp_size > 1
+            self.model_update_tmp.append(recv_obj)
+            # set future if all the results are received
+            if len(self.model_update_tmp) == self.server_args.dp_size:
+                self.model_update_result.set_result(self.model_update_tmp)
+
+    @property
+    def is_generation(self):
+        return self._generation_manager.model_config.is_generation
+
+    @property
+    def tokenizer(self):
+        return self._generation_manager.tokenizer
+
+    @property
+    def image_token_id(self):
+        return self._generation_manager.model_config.image_token_id
+
+
+async def print_exception_wrapper(func):
+    """
+    Sometimes an asyncio function does not print exception.
+    We do another wrapper to handle the exception.
+    """
+    try:
+        await func()
+    except Exception:
+        traceback = get_exception_traceback()
+        logger.error(f"TokenizerManager hit an exception: {traceback}")
+        kill_process_tree(os.getpid(), include_parent=True)
+        sys.exit(1)
+
+
+class SignalHandler:
+    def __init__(self, tokenizer_manager):
+        self.tokenizer_manager = tokenizer_manager
+
+    def signal_handler(self, signum=None, frame=None):
+        logger.warning(
+            f"SIGTERM received. {signum=} {frame=}. Draining requests and shutting down..." 
+ ) + self.tokenizer_manager.gracefully_exit = True + + +T = TypeVar("T") + + +class _Communicator(Generic[T]): + def __init__(self, sender, fan_out: int): + self._sender = sender + self._fan_out = fan_out + self._result_future: Optional[asyncio.Future] = None + self._result_values: Optional[List[T]] = None + + async def __call__(self, obj): + self._sender.send_pyobj(obj) + self._result_future = asyncio.Future() + self._result_values = [] + await self._result_future + result_values = self._result_values + self._result_future = self._result_values = None + return result_values + + def handle_recv(self, recv_obj: T): + self._result_values.append(recv_obj) + if len(self._result_values) == self._fan_out: + self._result_future.set_result(None) From 15049c2bf047a6963626a8be0265ce630f6b8ecf Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sat, 25 Jan 2025 10:02:25 +0800 Subject: [PATCH 044/136] rename class --- python/sglang/srt/orchestration/std/orchestrator.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/python/sglang/srt/orchestration/std/orchestrator.py b/python/sglang/srt/orchestration/std/orchestrator.py index 03527c60280..9e57150aff9 100644 --- a/python/sglang/srt/orchestration/std/orchestrator.py +++ b/python/sglang/srt/orchestration/std/orchestrator.py @@ -11,7 +11,6 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""TokenizerManager is a process that tokenizes the text.""" import asyncio import logging @@ -27,7 +26,6 @@ import zmq import zmq.asyncio from fastapi import BackgroundTasks - from sglang.srt.aio_rwlock import RWLock from sglang.srt.managers.generation_manager import GenerationManager from sglang.srt.managers.io_struct import ( @@ -66,8 +64,8 @@ logger = logging.getLogger(__name__) -class TokenizerManager: - """TokenizerManager is a process that tokenizes the text.""" +class StdOrchestrator: + """StdOrchestrator is the primary entrypoint of orchestration.std package""" def __init__( self, @@ -450,7 +448,7 @@ async def print_exception_wrapper(func): await func() except Exception: traceback = get_exception_traceback() - logger.error(f"TokenizerManager hit an exception: {traceback}") + logger.error(f"StdOrchestrator hit an exception: {traceback}") kill_process_tree(os.getpid(), include_parent=True) sys.exit(1) From b6dcf819904244a53b39067c9449286846ae1951 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sat, 25 Jan 2025 10:02:57 +0800 Subject: [PATCH 045/136] fix import --- python/sglang/srt/entrypoints/engine.py | 2 +- python/sglang/srt/entrypoints/http_server.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/entrypoints/engine.py b/python/sglang/srt/entrypoints/engine.py index 310e92c23d9..e0e2e902bfd 100644 --- a/python/sglang/srt/entrypoints/engine.py +++ b/python/sglang/srt/entrypoints/engine.py @@ -48,7 +48,7 @@ UpdateWeightsFromTensorReqInput, ) from sglang.srt.managers.scheduler import run_scheduler_process -from sglang.srt.managers.tokenizer_manager import TokenizerManager +from sglang.srt.orchestration.std.orchestrator import StdOrchestrator from sglang.srt.openai_api.adapter import load_chat_template_for_openai_api from sglang.srt.server_args import PortArgs, ServerArgs from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter diff --git a/python/sglang/srt/entrypoints/http_server.py b/python/sglang/srt/entrypoints/http_server.py index 
0ebce1a85d5..6a98e59d653 100644 --- a/python/sglang/srt/entrypoints/http_server.py +++ b/python/sglang/srt/entrypoints/http_server.py @@ -52,7 +52,7 @@ UpdateWeightFromDiskReqInput, UpdateWeightsFromDistributedReqInput, ) -from sglang.srt.managers.tokenizer_manager import TokenizerManager +from sglang.srt.orchestration.std.orchestrator import StdOrchestrator from sglang.srt.metrics.func_timer import enable_func_timer from sglang.srt.openai_api.adapter import ( v1_batches, From 4572406ec6fa8388730c1310a91e52d5bc6fe65b Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sat, 25 Jan 2025 10:03:27 +0800 Subject: [PATCH 046/136] rename class --- python/sglang/srt/entrypoints/engine.py | 10 +++++----- python/sglang/srt/entrypoints/http_server.py | 6 +++--- python/sglang/srt/managers/image_processor.py | 4 ++-- python/sglang/srt/managers/io_struct.py | 2 +- python/sglang/srt/managers/scheduler.py | 2 +- python/sglang/srt/server_args.py | 2 +- 6 files changed, 13 insertions(+), 13 deletions(-) diff --git a/python/sglang/srt/entrypoints/engine.py b/python/sglang/srt/entrypoints/engine.py index e0e2e902bfd..6c5215e399f 100644 --- a/python/sglang/srt/entrypoints/engine.py +++ b/python/sglang/srt/entrypoints/engine.py @@ -73,12 +73,12 @@ class Engine: The entry point to the inference engine. - The engine consists of three components: - 1. TokenizerManager: Tokenizes the requests and sends them to the scheduler. + 1. StdOrchestrator: Tokenizes the requests and sends them to the scheduler. 2. Scheduler (subprocess): Receives requests from the Tokenizer Manager, schedules batches, forwards them, and sends the output tokens to the Detokenizer Manager. 3. DetokenizerManager (subprocess): Detokenizes the output tokens and sends the result back to the Tokenizer Manager. Note: - 1. The HTTP server, Engine, and TokenizerManager both run in the main process. + 1. The HTTP server, Engine, and StdOrchestrator both run in the main process. 2. Inter-process communication is done through ICP (each process uses a different port) via the ZMQ library. """ @@ -337,9 +337,9 @@ def sigquit_handler(signum, frame): mp.set_start_method("spawn", force=True) -def _launch_subprocesses(server_args: ServerArgs) -> Tuple[TokenizerManager, Dict]: +def _launch_subprocesses(server_args: ServerArgs) -> Tuple[StdOrchestrator, Dict]: """ - Launch the TokenizerManager in the main process, the Scheduler in a subprocess, and the DetokenizerManager in another subprocess. + Launch the StdOrchestrator in the main process, the Scheduler in a subprocess, and the DetokenizerManager in another subprocess. 
""" # Configure global environment configure_logger(server_args) @@ -420,7 +420,7 @@ def _launch_subprocesses(server_args: ServerArgs) -> Tuple[TokenizerManager, Dic detoken_proc.start() # Launch tokenizer process - tokenizer_manager = TokenizerManager(server_args, port_args) + tokenizer_manager = StdOrchestrator(server_args, port_args) if server_args.chat_template: load_chat_template_for_openai_api(tokenizer_manager, server_args.chat_template) diff --git a/python/sglang/srt/entrypoints/http_server.py b/python/sglang/srt/entrypoints/http_server.py index 6a98e59d653..36c3ed3092a 100644 --- a/python/sglang/srt/entrypoints/http_server.py +++ b/python/sglang/srt/entrypoints/http_server.py @@ -95,7 +95,7 @@ # Store global states @dataclasses.dataclass class _GlobalState: - tokenizer_manager: TokenizerManager + tokenizer_manager: StdOrchestrator scheduler_info: Dict @@ -456,12 +456,12 @@ def launch_server( - HTTP server: A FastAPI server that routes requests to the engine. - The engine consists of three components: - 1. TokenizerManager: Tokenizes the requests and sends them to the scheduler. + 1. StdOrchestrator: Tokenizes the requests and sends them to the scheduler. 2. Scheduler (subprocess): Receives requests from the Tokenizer Manager, schedules batches, forwards them, and sends the output tokens to the Detokenizer Manager. 3. DetokenizerManager (subprocess): Detokenizes the output tokens and sends the result back to the Tokenizer Manager. Note: - 1. The HTTP server, Engine, and TokenizerManager both run in the main process. + 1. The HTTP server, Engine, and StdOrchestrator both run in the main process. 2. Inter-process communication is done through ICP (each process uses a different port) via the ZMQ library. """ tokenizer_manager, scheduler_info = _launch_subprocesses(server_args=server_args) diff --git a/python/sglang/srt/managers/image_processor.py b/python/sglang/srt/managers/image_processor.py index c8ebbed783a..0934db00250 100644 --- a/python/sglang/srt/managers/image_processor.py +++ b/python/sglang/srt/managers/image_processor.py @@ -109,7 +109,7 @@ def _process_single_image_task( return pixel_values, image_hash, image.size except Exception: - logger.error("Exception in TokenizerManager:\n" + get_exception_traceback()) + logger.error("Exception in StdOrchestrator:\n" + get_exception_traceback()) async def _process_single_image( self, image_data: Union[bytes, str], aspect_ratio: str, grid_pinpoints: str @@ -424,7 +424,7 @@ def _process_single_image_task( return pixel_values, image_hash, image.size, image_grid_thws except Exception: - logger.error("Exception in TokenizerManager:\n" + get_exception_traceback()) + logger.error("Exception in StdOrchestrator:\n" + get_exception_traceback()) async def _process_single_image(self, image_data: Union[bytes, str]): if self.executor is not None: diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py index eee9b6722d4..f4589ad1f83 100644 --- a/python/sglang/srt/managers/io_struct.py +++ b/python/sglang/srt/managers/io_struct.py @@ -13,7 +13,7 @@ # ============================================================================== """ The definition of objects transfered between different -processes (TokenizerManager, DetokenizerManager, Controller). +processes (StdOrchestrator, DetokenizerManager, Controller). 
""" import uuid diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 85bd1c2a4ad..80d63be7b65 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -172,7 +172,7 @@ def __init__( ) if server_args.skip_tokenizer_init: - # Directly send to the TokenizerManager + # Directly send to the StdOrchestrator self.send_to_detokenizer = get_zmq_socket( context, zmq.PUSH, port_args.tokenizer_ipc_name, False ) diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 330c3813288..26947d0de18 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -985,7 +985,7 @@ def init_new(server_args, dp_rank: Optional[int] = None) -> "PortArgs": if dp_rank is None: scheduler_input_port = ( port_base + 2 - ) # TokenizerManager to DataParallelController + ) # StdOrchestrator to DataParallelController else: scheduler_input_port = port_base + 2 + 1 + dp_rank From 6f01739e3b9ad18f4713823a53754217638e4c45 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sat, 25 Jan 2025 10:04:27 +0800 Subject: [PATCH 047/136] rename tokenizer_manager --- python/sglang/srt/entrypoints/engine.py | 36 ++++----- python/sglang/srt/entrypoints/http_server.py | 68 ++++++++--------- python/sglang/srt/openai_api/adapter.py | 76 +++++++++---------- .../srt/orchestration/std/orchestrator.py | 6 +- 4 files changed, 93 insertions(+), 93 deletions(-) diff --git a/python/sglang/srt/entrypoints/engine.py b/python/sglang/srt/entrypoints/engine.py index 6c5215e399f..7f80b8d2eec 100644 --- a/python/sglang/srt/entrypoints/engine.py +++ b/python/sglang/srt/entrypoints/engine.py @@ -101,10 +101,10 @@ def __init__(self, **kwargs): atexit.register(self.shutdown) # Launch subprocesses - tokenizer_manager, scheduler_info = _launch_subprocesses( + orchestrator, scheduler_info = _launch_subprocesses( server_args=server_args ) - self.tokenizer_manager = tokenizer_manager + self.orchestrator = orchestrator self.scheduler_info = scheduler_info def generate( @@ -137,7 +137,7 @@ def generate( stream=stream, ) loop = asyncio.get_event_loop() - generator = self.tokenizer_manager.generate_request(obj, None) + generator = self.orchestrator.generate_request(obj, None) if stream: @@ -183,7 +183,7 @@ async def async_generate( stream=stream, custom_logit_processor=custom_logit_processor, ) - generator = self.tokenizer_manager.generate_request(obj, None) + generator = self.orchestrator.generate_request(obj, None) if stream is True: return generator @@ -201,7 +201,7 @@ def encode( obj = EmbeddingReqInput(text=prompt) loop = asyncio.get_event_loop() - generator = self.tokenizer_manager.generate_request(obj, None) + generator = self.orchestrator.generate_request(obj, None) ret = loop.run_until_complete(generator.__anext__()) return ret @@ -210,14 +210,14 @@ def shutdown(self): kill_process_tree(os.getpid(), include_parent=False) def start_profile(self): - self.tokenizer_manager.start_profile() + self.orchestrator.start_profile() def stop_profile(self): - self.tokenizer_manager.stop_profile() + self.orchestrator.stop_profile() def get_server_info(self): return { - **dataclasses.asdict(self.tokenizer_manager.server_args), # server args + **dataclasses.asdict(self.orchestrator.server_args), # server args **self.scheduler_info, "version": __version__, } @@ -242,7 +242,7 @@ def init_weights_update_group( ) loop = asyncio.get_event_loop() return loop.run_until_complete( - self.tokenizer_manager.init_weights_update_group(obj, None) + 
self.orchestrator.init_weights_update_group(obj, None) ) def update_weights_from_distributed(self, name: str, dtype, shape): @@ -254,7 +254,7 @@ def update_weights_from_distributed(self, name: str, dtype, shape): ) loop = asyncio.get_event_loop() return loop.run_until_complete( - self.tokenizer_manager.update_weights_from_distributed(obj, None) + self.orchestrator.update_weights_from_distributed(obj, None) ) def update_weights_from_tensor(self, named_tensors: List[Tuple[str, torch.Tensor]]): @@ -264,7 +264,7 @@ def update_weights_from_tensor(self, named_tensors: List[Tuple[str, torch.Tensor ) loop = asyncio.get_event_loop() return loop.run_until_complete( - self.tokenizer_manager.update_weights_from_tensor(obj, None) + self.orchestrator.update_weights_from_tensor(obj, None) ) def get_weights_by_name(self, name: str, truncate_size: int = 100): @@ -272,7 +272,7 @@ def get_weights_by_name(self, name: str, truncate_size: int = 100): obj = GetWeightsByNameReqInput(name=name, truncate_size=truncate_size) loop = asyncio.get_event_loop() return loop.run_until_complete( - self.tokenizer_manager.get_weights_by_name(obj, None) + self.orchestrator.get_weights_by_name(obj, None) ) def release_memory_occupation(self): @@ -280,7 +280,7 @@ def release_memory_occupation(self): obj = ReleaseMemoryOccupationReqInput() loop = asyncio.get_event_loop() return loop.run_until_complete( - self.tokenizer_manager.release_memory_occupation(obj, None) + self.orchestrator.release_memory_occupation(obj, None) ) def resume_memory_occupation(self): @@ -288,7 +288,7 @@ def resume_memory_occupation(self): obj = ResumeMemoryOccupationReqInput() loop = asyncio.get_event_loop() return loop.run_until_complete( - self.tokenizer_manager.resume_memory_occupation(obj, None) + self.orchestrator.resume_memory_occupation(obj, None) ) @@ -420,9 +420,9 @@ def _launch_subprocesses(server_args: ServerArgs) -> Tuple[StdOrchestrator, Dict detoken_proc.start() # Launch tokenizer process - tokenizer_manager = StdOrchestrator(server_args, port_args) + orchestrator = StdOrchestrator(server_args, port_args) if server_args.chat_template: - load_chat_template_for_openai_api(tokenizer_manager, server_args.chat_template) + load_chat_template_for_openai_api(orchestrator, server_args.chat_template) # Wait for the model to finish loading scheduler_infos = [] @@ -445,5 +445,5 @@ def _launch_subprocesses(server_args: ServerArgs) -> Tuple[StdOrchestrator, Dict # Assume all schedulers have the same scheduler_info scheduler_info = scheduler_infos[0] - tokenizer_manager.max_req_input_len = scheduler_info["max_req_input_len"] - return tokenizer_manager, scheduler_info + orchestrator.max_req_input_len = scheduler_info["max_req_input_len"] + return orchestrator, scheduler_info diff --git a/python/sglang/srt/entrypoints/http_server.py b/python/sglang/srt/entrypoints/http_server.py index 36c3ed3092a..2ea359a9400 100644 --- a/python/sglang/srt/entrypoints/http_server.py +++ b/python/sglang/srt/entrypoints/http_server.py @@ -95,7 +95,7 @@ # Store global states @dataclasses.dataclass class _GlobalState: - tokenizer_manager: StdOrchestrator + orchestrator: StdOrchestrator scheduler_info: Dict @@ -122,7 +122,7 @@ async def health_generate(request: Request) -> Response: sampling_params = {"max_new_tokens": 1, "temperature": 0.7} - if _global_state.tokenizer_manager.is_generation: + if _global_state.orchestrator.is_generation: gri = GenerateReqInput( input_ids=[0], sampling_params=sampling_params, log_metrics=False ) @@ -132,7 +132,7 @@ async def 
health_generate(request: Request) -> Response: ) try: - async for _ in _global_state.tokenizer_manager.generate_request(gri, request): + async for _ in _global_state.orchestrator.generate_request(gri, request): break return Response(status_code=200) except Exception as e: @@ -144,9 +144,9 @@ async def health_generate(request: Request) -> Response: async def get_model_info(): """Get the model information.""" result = { - "model_path": _global_state.tokenizer_manager.model_path, - "tokenizer_path": _global_state.tokenizer_manager.server_args.tokenizer_path, - "is_generation": _global_state.tokenizer_manager.is_generation, + "model_path": _global_state.orchestrator.model_path, + "tokenizer_path": _global_state.orchestrator.server_args.tokenizer_path, + "is_generation": _global_state.orchestrator.is_generation, } return result @@ -154,7 +154,7 @@ async def get_model_info(): @app.get("/get_server_info") async def get_server_info(): return { - **dataclasses.asdict(_global_state.tokenizer_manager.server_args), + **dataclasses.asdict(_global_state.orchestrator.server_args), **_global_state.scheduler_info, "version": __version__, } @@ -168,7 +168,7 @@ async def generate_request(obj: GenerateReqInput, request: Request): async def stream_results() -> AsyncIterator[bytes]: try: - async for out in _global_state.tokenizer_manager.generate_request( + async for out in _global_state.orchestrator.generate_request( obj, request ): yield b"data: " + orjson.dumps( @@ -184,11 +184,11 @@ async def stream_results() -> AsyncIterator[bytes]: return StreamingResponse( stream_results(), media_type="text/event-stream", - background=_global_state.tokenizer_manager.create_abort_task(obj), + background=_global_state.orchestrator.create_abort_task(obj), ) else: try: - ret = await _global_state.tokenizer_manager.generate_request( + ret = await _global_state.orchestrator.generate_request( obj, request ).__anext__() return ret @@ -201,7 +201,7 @@ async def stream_results() -> AsyncIterator[bytes]: async def encode_request(obj: EmbeddingReqInput, request: Request): """Handle an embedding request.""" try: - ret = await _global_state.tokenizer_manager.generate_request( + ret = await _global_state.orchestrator.generate_request( obj, request ).__anext__() return ret @@ -213,7 +213,7 @@ async def encode_request(obj: EmbeddingReqInput, request: Request): async def classify_request(obj: EmbeddingReqInput, request: Request): """Handle a reward model request. Now the arguments and return values are the same as embedding models.""" try: - ret = await _global_state.tokenizer_manager.generate_request( + ret = await _global_state.orchestrator.generate_request( obj, request ).__anext__() return ret @@ -224,7 +224,7 @@ async def classify_request(obj: EmbeddingReqInput, request: Request): @app.post("/flush_cache") async def flush_cache(): """Flush the radix cache.""" - _global_state.tokenizer_manager.flush_cache() + _global_state.orchestrator.flush_cache() return Response( content="Cache flushed.\nPlease check backend logs for more details. 
" "(When there are running or waiting requests, the operation will not be performed.)\n", @@ -235,7 +235,7 @@ async def flush_cache(): @app.api_route("/start_profile", methods=["GET", "POST"]) async def start_profile_async(): """Start profiling.""" - _global_state.tokenizer_manager.start_profile() + _global_state.orchestrator.start_profile() return Response( content="Start profiling.\n", status_code=200, @@ -245,7 +245,7 @@ async def start_profile_async(): @app.api_route("/stop_profile", methods=["GET", "POST"]) async def stop_profile_async(): """Stop profiling.""" - _global_state.tokenizer_manager.stop_profile() + _global_state.orchestrator.stop_profile() return Response( content="Stop profiling. This will take some time.\n", status_code=200, @@ -255,7 +255,7 @@ async def stop_profile_async(): @app.post("/update_weights_from_disk") async def update_weights_from_disk(obj: UpdateWeightFromDiskReqInput, request: Request): """Update the weights from disk in-place without re-launching the server.""" - success, message = await _global_state.tokenizer_manager.update_weights_from_disk( + success, message = await _global_state.orchestrator.update_weights_from_disk( obj, request ) content = {"success": success, "message": message} @@ -276,7 +276,7 @@ async def init_weights_update_group( obj: InitWeightsUpdateGroupReqInput, request: Request ): """Initialize the parameter update group.""" - success, message = await _global_state.tokenizer_manager.init_weights_update_group( + success, message = await _global_state.orchestrator.init_weights_update_group( obj, request ) content = {"success": success, "message": message} @@ -292,7 +292,7 @@ async def update_weights_from_distributed( ): """Update model parameter from distributed online.""" success, message = ( - await _global_state.tokenizer_manager.update_weights_from_distributed( + await _global_state.orchestrator.update_weights_from_distributed( obj, request ) ) @@ -307,7 +307,7 @@ async def update_weights_from_distributed( async def get_weights_by_name(obj: GetWeightsByNameReqInput, request: Request): """Get model parameter by name.""" try: - ret = await _global_state.tokenizer_manager.get_weights_by_name(obj, request) + ret = await _global_state.orchestrator.get_weights_by_name(obj, request) if ret is None: return _create_error_response("Get parameter by name failed") else: @@ -322,7 +322,7 @@ async def release_memory_occupation( ): """Release GPU occupation temporarily""" try: - await _global_state.tokenizer_manager.release_memory_occupation(obj, request) + await _global_state.orchestrator.release_memory_occupation(obj, request) except Exception as e: return _create_error_response(e) @@ -333,7 +333,7 @@ async def resume_memory_occupation( ): """Resume GPU occupation""" try: - await _global_state.tokenizer_manager.resume_memory_occupation(obj, request) + await _global_state.orchestrator.resume_memory_occupation(obj, request) except Exception as e: return _create_error_response(e) @@ -342,7 +342,7 @@ async def resume_memory_occupation( async def open_session(obj: OpenSessionReqInput, request: Request): """Open a session, and return its unique session id.""" try: - session_id = await _global_state.tokenizer_manager.open_session(obj, request) + session_id = await _global_state.orchestrator.open_session(obj, request) if session_id is None: raise Exception( "Failed to open the session. Check if a session with the same id is still open." 
@@ -356,7 +356,7 @@ async def open_session(obj: OpenSessionReqInput, request: Request): async def close_session(obj: CloseSessionReqInput, request: Request): """Close the session""" try: - await _global_state.tokenizer_manager.close_session(obj, request) + await _global_state.orchestrator.close_session(obj, request) return Response(status_code=200) except Exception as e: return _create_error_response(e) @@ -365,7 +365,7 @@ async def close_session(obj: CloseSessionReqInput, request: Request): @app.api_route("/configure_logging", methods=["GET", "POST"]) async def configure_logging(obj: ConfigureLoggingReq, request: Request): """Close the session""" - _global_state.tokenizer_manager.configure_logging(obj) + _global_state.orchestrator.configure_logging(obj) return Response(status_code=200) @@ -374,24 +374,24 @@ async def configure_logging(obj: ConfigureLoggingReq, request: Request): @app.post("/v1/completions") async def openai_v1_completions(raw_request: Request): - return await v1_completions(_global_state.tokenizer_manager, raw_request) + return await v1_completions(_global_state.orchestrator, raw_request) @app.post("/v1/chat/completions") async def openai_v1_chat_completions(raw_request: Request): - return await v1_chat_completions(_global_state.tokenizer_manager, raw_request) + return await v1_chat_completions(_global_state.orchestrator, raw_request) @app.post("/v1/embeddings", response_class=ORJSONResponse) async def openai_v1_embeddings(raw_request: Request): - response = await v1_embeddings(_global_state.tokenizer_manager, raw_request) + response = await v1_embeddings(_global_state.orchestrator, raw_request) return response @app.get("/v1/models", response_class=ORJSONResponse) def available_models(): """Show available models.""" - served_model_names = [_global_state.tokenizer_manager.served_model_name] + served_model_names = [_global_state.orchestrator.served_model_name] model_cards = [] for served_model_name in served_model_names: model_cards.append(ModelCard(id=served_model_name, root=served_model_name)) @@ -401,7 +401,7 @@ def available_models(): @app.post("/v1/files") async def openai_v1_files(file: UploadFile = File(...), purpose: str = Form("batch")): return await v1_files_create( - file, purpose, _global_state.tokenizer_manager.server_args.file_storage_pth + file, purpose, _global_state.orchestrator.server_args.file_storage_pth ) @@ -413,13 +413,13 @@ async def delete_file(file_id: str): @app.post("/v1/batches") async def openai_v1_batches(raw_request: Request): - return await v1_batches(_global_state.tokenizer_manager, raw_request) + return await v1_batches(_global_state.orchestrator, raw_request) @app.post("/v1/batches/{batch_id}/cancel") async def cancel_batches(batch_id: str): # https://platform.openai.com/docs/api-reference/batch/cancel - return await v1_cancel_batch(_global_state.tokenizer_manager, batch_id) + return await v1_cancel_batch(_global_state.orchestrator, batch_id) @app.get("/v1/batches/{batch_id}") @@ -464,10 +464,10 @@ def launch_server( 1. The HTTP server, Engine, and StdOrchestrator both run in the main process. 2. Inter-process communication is done through ICP (each process uses a different port) via the ZMQ library. 
""" - tokenizer_manager, scheduler_info = _launch_subprocesses(server_args=server_args) + orchestrator, scheduler_info = _launch_subprocesses(server_args=server_args) set_global_state( _GlobalState( - tokenizer_manager=tokenizer_manager, + orchestrator=orchestrator, scheduler_info=scheduler_info, ) ) @@ -487,7 +487,7 @@ def launch_server( args=( server_args, pipe_finish_writer, - _global_state.tokenizer_manager.image_token_id, + _global_state.orchestrator.image_token_id, ), ) t.start() diff --git a/python/sglang/srt/openai_api/adapter.py b/python/sglang/srt/openai_api/adapter.py index 5056ba22ef9..2222a8d2047 100644 --- a/python/sglang/srt/openai_api/adapter.py +++ b/python/sglang/srt/openai_api/adapter.py @@ -116,7 +116,7 @@ def create_streaming_error_response( return json_str -def load_chat_template_for_openai_api(tokenizer_manager, chat_template_arg): +def load_chat_template_for_openai_api(orchestrator, chat_template_arg): global chat_template_name logger.info( @@ -131,7 +131,7 @@ def load_chat_template_for_openai_api(tokenizer_manager, chat_template_arg): if chat_template_arg.endswith(".jinja"): with open(chat_template_arg, "r") as f: chat_template = "".join(f.readlines()).strip("\n") - tokenizer_manager.tokenizer.chat_template = chat_template.replace( + orchestrator.tokenizer.chat_template = chat_template.replace( "\\n", "\n" ) chat_template_name = None @@ -217,7 +217,7 @@ async def v1_delete_file(file_id: str): return FileDeleteResponse(id=file_id, deleted=True) -async def v1_batches(tokenizer_manager, raw_request: Request): +async def v1_batches(orchestrator, raw_request: Request): try: body = await raw_request.json() @@ -238,7 +238,7 @@ async def v1_batches(tokenizer_manager, raw_request: Request): batch_storage[batch_id] = batch_response # Start processing the batch asynchronously - asyncio.create_task(process_batch(tokenizer_manager, batch_id, batch_request)) + asyncio.create_task(process_batch(orchestrator, batch_id, batch_request)) # Return the initial batch_response return batch_response @@ -249,7 +249,7 @@ async def v1_batches(tokenizer_manager, raw_request: Request): return {"error": str(e)} -async def process_batch(tokenizer_manager, batch_id: str, batch_request: BatchRequest): +async def process_batch(orchestrator, batch_id: str, batch_request: BatchRequest): try: # Update the batch status to "in_progress" batch_storage[batch_id].status = "in_progress" @@ -292,7 +292,7 @@ async def process_batch(tokenizer_manager, batch_id: str, batch_request: BatchRe if end_point == "/v1/chat/completions": adapted_request, request = v1_chat_generate_request( - all_requests, tokenizer_manager, request_ids=request_ids + all_requests, orchestrator, request_ids=request_ids ) elif end_point == "/v1/completions": adapted_request, request = v1_generate_request( @@ -300,7 +300,7 @@ async def process_batch(tokenizer_manager, batch_id: str, batch_request: BatchRe ) try: - ret = await tokenizer_manager.generate_request(adapted_request).__anext__() + ret = await orchestrator.generate_request(adapted_request).__anext__() if not isinstance(ret, list): ret = [ret] if end_point == "/v1/chat/completions": @@ -308,11 +308,11 @@ async def process_batch(tokenizer_manager, batch_id: str, batch_request: BatchRe request, ret, to_file=True, - cache_report=tokenizer_manager.server_args.enable_cache_report, + cache_report=orchestrator.server_args.enable_cache_report, ) else: responses = v1_generate_response( - request, ret, tokenizer_manager, to_file=True + request, ret, orchestrator, to_file=True ) except 
Exception as e: @@ -384,7 +384,7 @@ async def v1_retrieve_batch(batch_id: str): return batch_response -async def v1_cancel_batch(tokenizer_manager, batch_id: str): +async def v1_cancel_batch(orchestrator, batch_id: str): # Retrieve the batch job from the in-memory storage batch_response = batch_storage.get(batch_id) if batch_response is None: @@ -395,7 +395,7 @@ async def v1_cancel_batch(tokenizer_manager, batch_id: str): # Start cancelling the batch asynchronously asyncio.create_task( cancel_batch( - tokenizer_manager=tokenizer_manager, + orchestrator=orchestrator, batch_id=batch_id, input_file_id=batch_response.input_file_id, ) @@ -412,7 +412,7 @@ async def v1_cancel_batch(tokenizer_manager, batch_id: str): ) -async def cancel_batch(tokenizer_manager, batch_id: str, input_file_id: str): +async def cancel_batch(orchestrator, batch_id: str, input_file_id: str): try: # Update the batch status to "cancelling" batch_storage[batch_id].status = "cancelling" @@ -436,7 +436,7 @@ async def cancel_batch(tokenizer_manager, batch_id: str, input_file_id: str): # Cancel requests by request_ids for rid in request_ids: - tokenizer_manager.abort_request(rid=rid) + orchestrator.abort_request(rid=rid) retrieve_batch = batch_storage[batch_id] retrieve_batch.status = "cancelled" @@ -564,7 +564,7 @@ def v1_generate_request( return adapted_request, all_requests if len(all_requests) > 1 else all_requests[0] -def v1_generate_response(request, ret, tokenizer_manager, to_file=False): +def v1_generate_response(request, ret, orchestrator, to_file=False): choices = [] echo = False @@ -576,13 +576,13 @@ def v1_generate_response(request, ret, tokenizer_manager, to_file=False): elif isinstance(request.prompt, list) and isinstance(request.prompt[0], list): # for the case of multiple token ids prompts prompts = [ - tokenizer_manager.tokenizer.decode(prompt, skip_special_tokens=True) + orchestrator.tokenizer.decode(prompt, skip_special_tokens=True) for prompt in request.prompt ] elif isinstance(request.prompt, list) and isinstance(request.prompt[0], int): # for the case of single token ids prompt prompts = [ - tokenizer_manager.tokenizer.decode( + orchestrator.tokenizer.decode( request.prompt, skip_special_tokens=True ) ] @@ -694,7 +694,7 @@ def v1_generate_response(request, ret, tokenizer_manager, to_file=False): return response -async def v1_completions(tokenizer_manager, raw_request: Request): +async def v1_completions(orchestrator, raw_request: Request): request_json = await raw_request.json() all_requests = [CompletionRequest(**request_json)] adapted_request, request = v1_generate_request(all_requests) @@ -707,7 +707,7 @@ async def generate_stream_resp(): prompt_tokens = {} completion_tokens = {} try: - async for content in tokenizer_manager.generate_request( + async for content in orchestrator.generate_request( adapted_request, raw_request ): index = content.get("index", 0) @@ -730,14 +730,14 @@ async def generate_stream_resp(): prompts = request.prompt[index // request.n] elif isinstance(request.prompt[0], int): # for the case of single token ids prompt - prompts = tokenizer_manager.tokenizer.decode( + prompts = orchestrator.tokenizer.decode( request.prompt, skip_special_tokens=True ) elif isinstance(request.prompt[0], list) and isinstance( request.prompt[0][0], int ): # for the case of multiple token ids prompts - prompts = tokenizer_manager.tokenizer.decode( + prompts = orchestrator.tokenizer.decode( request.prompt[index // request.n], skip_special_tokens=True, ) @@ -832,12 +832,12 @@ async def 
generate_stream_resp(): return StreamingResponse( generate_stream_resp(), media_type="text/event-stream", - background=tokenizer_manager.create_abort_task(adapted_request), + background=orchestrator.create_abort_task(adapted_request), ) # Non-streaming response. try: - ret = await tokenizer_manager.generate_request( + ret = await orchestrator.generate_request( adapted_request, raw_request ).__anext__() except ValueError as e: @@ -846,13 +846,13 @@ async def generate_stream_resp(): if not isinstance(ret, list): ret = [ret] - response = v1_generate_response(request, ret, tokenizer_manager) + response = v1_generate_response(request, ret, orchestrator) return response def v1_chat_generate_request( all_requests: List[ChatCompletionRequest], - tokenizer_manager, + orchestrator, request_ids: List[str] = None, ): input_ids = [] @@ -908,14 +908,14 @@ def v1_chat_generate_request( openai_compatible_messages = openai_compatible_messages[:-1] else: assistant_prefix = None - prompt_ids = tokenizer_manager.tokenizer.apply_chat_template( + prompt_ids = orchestrator.tokenizer.apply_chat_template( openai_compatible_messages, tokenize=True, add_generation_prompt=True, tools=tools, ) if assistant_prefix: - prompt_ids += tokenizer_manager.tokenizer.encode(assistant_prefix) + prompt_ids += orchestrator.tokenizer.encode(assistant_prefix) stop = request.stop image_data = None modalities = [] @@ -930,7 +930,7 @@ def v1_chat_generate_request( stop.append(request.stop) else: stop.extend(request.stop) - prompt_ids = tokenizer_manager.tokenizer.encode(prompt) + prompt_ids = orchestrator.tokenizer.encode(prompt) else: # Use the raw prompt and stop strings if the messages is already a string. prompt_ids = request.messages @@ -1166,10 +1166,10 @@ def v1_chat_generate_response(request, ret, to_file=False, cache_report=False): return response -async def v1_chat_completions(tokenizer_manager, raw_request: Request): +async def v1_chat_completions(orchestrator, raw_request: Request): request_json = await raw_request.json() all_requests = [ChatCompletionRequest(**request_json)] - adapted_request, request = v1_chat_generate_request(all_requests, tokenizer_manager) + adapted_request, request = v1_chat_generate_request(all_requests, orchestrator) if adapted_request.stream: @@ -1180,7 +1180,7 @@ async def generate_stream_resp(): prompt_tokens = {} completion_tokens = {} try: - async for content in tokenizer_manager.generate_request( + async for content in orchestrator.generate_request( adapted_request, raw_request ): index = content.get("index", 0) @@ -1319,12 +1319,12 @@ async def generate_stream_resp(): return StreamingResponse( generate_stream_resp(), media_type="text/event-stream", - background=tokenizer_manager.create_abort_task(adapted_request), + background=orchestrator.create_abort_task(adapted_request), ) # Non-streaming response. 
try: - ret = await tokenizer_manager.generate_request( + ret = await orchestrator.generate_request( adapted_request, raw_request ).__anext__() except ValueError as e: @@ -1333,13 +1333,13 @@ async def generate_stream_resp(): ret = [ret] response = v1_chat_generate_response( - request, ret, cache_report=tokenizer_manager.server_args.enable_cache_report + request, ret, cache_report=orchestrator.server_args.enable_cache_report ) return response -def v1_embedding_request(all_requests, tokenizer_manager): +def v1_embedding_request(all_requests, orchestrator): prompts = [] sampling_params_list = [] first_prompt_type = type(all_requests[0].input) @@ -1394,13 +1394,13 @@ def v1_embedding_response(ret, model_path, to_file=False): ) -async def v1_embeddings(tokenizer_manager, raw_request: Request): +async def v1_embeddings(orchestrator, raw_request: Request): request_json = await raw_request.json() all_requests = [EmbeddingRequest(**request_json)] - adapted_request, request = v1_embedding_request(all_requests, tokenizer_manager) + adapted_request, request = v1_embedding_request(all_requests, orchestrator) try: - ret = await tokenizer_manager.generate_request( + ret = await orchestrator.generate_request( adapted_request, raw_request ).__anext__() except ValueError as e: @@ -1409,7 +1409,7 @@ async def v1_embeddings(tokenizer_manager, raw_request: Request): if not isinstance(ret, list): ret = [ret] - response = v1_embedding_response(ret, tokenizer_manager.model_path) + response = v1_embedding_response(ret, orchestrator.model_path) return response diff --git a/python/sglang/srt/orchestration/std/orchestrator.py b/python/sglang/srt/orchestration/std/orchestrator.py index 9e57150aff9..77670bda667 100644 --- a/python/sglang/srt/orchestration/std/orchestrator.py +++ b/python/sglang/srt/orchestration/std/orchestrator.py @@ -454,14 +454,14 @@ async def print_exception_wrapper(func): class SignalHandler: - def __init__(self, tokenizer_manager): - self.tokenizer_manager = tokenizer_manager + def __init__(self, orchestrator): + self.orchestrator = orchestrator def signal_handler(self, signum=None, frame=None): logger.warning( f"SIGTERM received. {signum=} {frame=}. Draining requests and shutting down..." 
) - self.tokenizer_manager.gracefully_exit = True + self.orchestrator.gracefully_exit = True T = TypeVar("T") From b1932a630f6d37f508625d836b641f64f43e1eea Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sat, 25 Jan 2025 10:08:03 +0800 Subject: [PATCH 048/136] handle max_req_input_len --- python/sglang/srt/entrypoints/engine.py | 2 +- python/sglang/srt/managers/generation_manager.py | 16 +++++++++------- python/sglang/srt/managers/tokenizer_manager.py | 4 ++-- 3 files changed, 12 insertions(+), 10 deletions(-) diff --git a/python/sglang/srt/entrypoints/engine.py b/python/sglang/srt/entrypoints/engine.py index 310e92c23d9..d1f71f25f8f 100644 --- a/python/sglang/srt/entrypoints/engine.py +++ b/python/sglang/srt/entrypoints/engine.py @@ -445,5 +445,5 @@ def _launch_subprocesses(server_args: ServerArgs) -> Tuple[TokenizerManager, Dic # Assume all schedulers have the same scheduler_info scheduler_info = scheduler_infos[0] - tokenizer_manager.max_req_input_len = scheduler_info["max_req_input_len"] + tokenizer_manager.configure_max_req_input_len(scheduler_info["max_req_input_len"]) return tokenizer_manager, scheduler_info diff --git a/python/sglang/srt/managers/generation_manager.py b/python/sglang/srt/managers/generation_manager.py index b32e66c6046..0827a2e5241 100644 --- a/python/sglang/srt/managers/generation_manager.py +++ b/python/sglang/srt/managers/generation_manager.py @@ -10,7 +10,6 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union import fastapi - from sglang.srt.configs.model_config import ModelConfig from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer from sglang.srt.managers.image_processor import ( @@ -46,7 +45,7 @@ def __init__( self.on_request = on_request self.model_config = _compute_model_config(server_args) - self._generation_converter = GenerationConverter(server_args=server_args) + self.generation_converter = GenerationConverter(server_args=server_args) self.rid_to_state: Dict[str, _ReqState] = {} @@ -80,7 +79,7 @@ async def generate( is_single = obj.is_single if is_single: - tokenized_obj = await self._generation_converter.tokenize_request(obj) + tokenized_obj = await self.generation_converter.tokenize_request(obj) self._send_one_request(obj, tokenized_obj, created_time) async for response in self._wait_one_response(obj, request): yield response @@ -162,7 +161,7 @@ async def _handle_batch_request( # Send all requests for i in range(batch_size): tmp_obj = obj[i] - tokenized_obj = await self._generation_converter.tokenize_request( + tokenized_obj = await self.generation_converter.tokenize_request( tmp_obj ) self._send_one_request(tmp_obj, tokenized_obj, created_time) @@ -180,7 +179,7 @@ async def _handle_batch_request( # Tokenize all requests objs = [obj[i] for i in range(batch_size)] tokenized_objs = await asyncio.gather( - *(self._generation_converter.tokenize_request(obj) for obj in objs) + *(self.generation_converter.tokenize_request(obj) for obj in objs) ) # Cache the common prefix for parallel sampling @@ -236,7 +235,7 @@ def handle_batch_output( if state is None: continue - out_dict = self._generation_converter.postprocess_response( + out_dict = self.generation_converter.postprocess_response( recv_obj, index, state.obj ) @@ -264,7 +263,7 @@ def abort_request(self, rid: str): @property def tokenizer(self): - return self._generation_converter.tokenizer + return self.generation_converter.tokenizer class GenerationConverter: @@ -280,6 +279,9 @@ def __init__( # Create image processor placeholder self.image_processor = 
get_dummy_image_processor() + # Set after scheduler is initialized + self.max_req_input_len = None + # Create tokenizer if server_args.skip_tokenizer_init: self.tokenizer = self.processor = None diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 03527c60280..b6960b1c3bc 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -130,8 +130,6 @@ def __init__( self.resume_memory_occupation_communicator = _Communicator( self.send_to_scheduler, server_args.dp_size ) - # Set after scheduler is initialized - self.max_req_input_len = None self._result_dispatcher = TypeBasedDispatcher( [ @@ -440,6 +438,8 @@ def tokenizer(self): def image_token_id(self): return self._generation_manager.model_config.image_token_id + def configure_max_req_input_len(self, max_req_input_len): + self._generation_manager.generation_converter.max_req_input_len = max_req_input_len async def print_exception_wrapper(func): """ From 559ecbafe2ebc893b4c7d8de2781ef332892271d Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sat, 25 Jan 2025 10:08:10 +0800 Subject: [PATCH 049/136] fmt --- python/sglang/srt/managers/generation_manager.py | 1 + python/sglang/srt/managers/tokenizer_manager.py | 5 ++++- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/managers/generation_manager.py b/python/sglang/srt/managers/generation_manager.py index 0827a2e5241..f87d5d75fe9 100644 --- a/python/sglang/srt/managers/generation_manager.py +++ b/python/sglang/srt/managers/generation_manager.py @@ -10,6 +10,7 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union import fastapi + from sglang.srt.configs.model_config import ModelConfig from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer from sglang.srt.managers.image_processor import ( diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index b6960b1c3bc..6d86602d128 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -439,7 +439,10 @@ def image_token_id(self): return self._generation_manager.model_config.image_token_id def configure_max_req_input_len(self, max_req_input_len): - self._generation_manager.generation_converter.max_req_input_len = max_req_input_len + self._generation_manager.generation_converter.max_req_input_len = ( + max_req_input_len + ) + async def print_exception_wrapper(func): """ From fa4d9f6b4ca0d8e9af2ee33eb916e53cd1722920 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sat, 25 Jan 2025 10:09:11 +0800 Subject: [PATCH 050/136] fmt --- python/sglang/srt/entrypoints/engine.py | 10 +++------- python/sglang/srt/entrypoints/http_server.py | 8 +++----- python/sglang/srt/openai_api/adapter.py | 8 ++------ 3 files changed, 8 insertions(+), 18 deletions(-) diff --git a/python/sglang/srt/entrypoints/engine.py b/python/sglang/srt/entrypoints/engine.py index 213581756f3..e3cb85d71a4 100644 --- a/python/sglang/srt/entrypoints/engine.py +++ b/python/sglang/srt/entrypoints/engine.py @@ -48,8 +48,8 @@ UpdateWeightsFromTensorReqInput, ) from sglang.srt.managers.scheduler import run_scheduler_process -from sglang.srt.orchestration.std.orchestrator import StdOrchestrator from sglang.srt.openai_api.adapter import load_chat_template_for_openai_api +from sglang.srt.orchestration.std.orchestrator import StdOrchestrator from sglang.srt.server_args import PortArgs, ServerArgs from 
sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter from sglang.srt.utils import ( @@ -101,9 +101,7 @@ def __init__(self, **kwargs): atexit.register(self.shutdown) # Launch subprocesses - orchestrator, scheduler_info = _launch_subprocesses( - server_args=server_args - ) + orchestrator, scheduler_info = _launch_subprocesses(server_args=server_args) self.orchestrator = orchestrator self.scheduler_info = scheduler_info @@ -271,9 +269,7 @@ def get_weights_by_name(self, name: str, truncate_size: int = 100): """Get weights by parameter name.""" obj = GetWeightsByNameReqInput(name=name, truncate_size=truncate_size) loop = asyncio.get_event_loop() - return loop.run_until_complete( - self.orchestrator.get_weights_by_name(obj, None) - ) + return loop.run_until_complete(self.orchestrator.get_weights_by_name(obj, None)) def release_memory_occupation(self): """Release GPU occupation temporarily.""" diff --git a/python/sglang/srt/entrypoints/http_server.py b/python/sglang/srt/entrypoints/http_server.py index 2ea359a9400..d949877373a 100644 --- a/python/sglang/srt/entrypoints/http_server.py +++ b/python/sglang/srt/entrypoints/http_server.py @@ -52,7 +52,6 @@ UpdateWeightFromDiskReqInput, UpdateWeightsFromDistributedReqInput, ) -from sglang.srt.orchestration.std.orchestrator import StdOrchestrator from sglang.srt.metrics.func_timer import enable_func_timer from sglang.srt.openai_api.adapter import ( v1_batches, @@ -67,6 +66,7 @@ v1_retrieve_file_content, ) from sglang.srt.openai_api.protocol import ModelCard, ModelList +from sglang.srt.orchestration.std.orchestrator import StdOrchestrator from sglang.srt.server_args import ServerArgs from sglang.srt.utils import ( add_api_key_middleware, @@ -291,10 +291,8 @@ async def update_weights_from_distributed( obj: UpdateWeightsFromDistributedReqInput, request: Request ): """Update model parameter from distributed online.""" - success, message = ( - await _global_state.orchestrator.update_weights_from_distributed( - obj, request - ) + success, message = await _global_state.orchestrator.update_weights_from_distributed( + obj, request ) content = {"success": success, "message": message} if success: diff --git a/python/sglang/srt/openai_api/adapter.py b/python/sglang/srt/openai_api/adapter.py index 2222a8d2047..2189af24a18 100644 --- a/python/sglang/srt/openai_api/adapter.py +++ b/python/sglang/srt/openai_api/adapter.py @@ -131,9 +131,7 @@ def load_chat_template_for_openai_api(orchestrator, chat_template_arg): if chat_template_arg.endswith(".jinja"): with open(chat_template_arg, "r") as f: chat_template = "".join(f.readlines()).strip("\n") - orchestrator.tokenizer.chat_template = chat_template.replace( - "\\n", "\n" - ) + orchestrator.tokenizer.chat_template = chat_template.replace("\\n", "\n") chat_template_name = None else: assert chat_template_arg.endswith( @@ -582,9 +580,7 @@ def v1_generate_response(request, ret, orchestrator, to_file=False): elif isinstance(request.prompt, list) and isinstance(request.prompt[0], int): # for the case of single token ids prompt prompts = [ - orchestrator.tokenizer.decode( - request.prompt, skip_special_tokens=True - ) + orchestrator.tokenizer.decode(request.prompt, skip_special_tokens=True) ] else: # for the case of single str prompt From 71c29d9ad98d8bb0d57fe60f2c785030acc43b28 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sat, 25 Jan 2025 10:11:03 +0800 Subject: [PATCH 051/136] fmt --- python/sglang/srt/orchestration/std/orchestrator.py | 1 + 1 file changed, 1 insertion(+) diff --git 
a/python/sglang/srt/orchestration/std/orchestrator.py b/python/sglang/srt/orchestration/std/orchestrator.py index 07a5d023866..f5b875e5de4 100644 --- a/python/sglang/srt/orchestration/std/orchestrator.py +++ b/python/sglang/srt/orchestration/std/orchestrator.py @@ -26,6 +26,7 @@ import zmq import zmq.asyncio from fastapi import BackgroundTasks + from sglang.srt.aio_rwlock import RWLock from sglang.srt.managers.generation_manager import GenerationManager from sglang.srt.managers.io_struct import ( From 4d5be794ec81064fd3a87234de874638a4a21500 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sat, 25 Jan 2025 10:12:18 +0800 Subject: [PATCH 052/136] empty file --- .../srt/orchestration/std/detokenizer.py | 21 +++++++++++++++++++ 1 file changed, 21 insertions(+) create mode 100644 python/sglang/srt/orchestration/std/detokenizer.py diff --git a/python/sglang/srt/orchestration/std/detokenizer.py b/python/sglang/srt/orchestration/std/detokenizer.py new file mode 100644 index 00000000000..b6923c79c7e --- /dev/null +++ b/python/sglang/srt/orchestration/std/detokenizer.py @@ -0,0 +1,21 @@ +import logging +import signal + +import psutil +import setproctitle +import zmq + +from sglang.srt.managers.detokenizer_manager import DetokenizerManager +from sglang.srt.managers.io_struct import BatchEmbeddingOut, BatchTokenIDOut +from sglang.srt.server_args import PortArgs, ServerArgs +from sglang.srt.utils import configure_logger, get_zmq_socket +from sglang.utils import TypeBasedDispatcher, get_exception_traceback + +logger = logging.getLogger(__name__) + + +class DetokenizerManagerCommunicator: + def __init__(self, core: DetokenizerManager, port_args: PortArgs): + self.core = core + + TODO From 91c604d6781ae6e18fb34a2d0f1527b88a1121ec Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sat, 25 Jan 2025 10:12:49 +0800 Subject: [PATCH 053/136] extract detokenizer_manager.init --- .../sglang/srt/managers/detokenizer_manager.py | 9 --------- .../srt/orchestration/std/detokenizer.py | 18 ++++++++++-------- 2 files changed, 10 insertions(+), 17 deletions(-) diff --git a/python/sglang/srt/managers/detokenizer_manager.py b/python/sglang/srt/managers/detokenizer_manager.py index 972f9595b2c..265ce841028 100644 --- a/python/sglang/srt/managers/detokenizer_manager.py +++ b/python/sglang/srt/managers/detokenizer_manager.py @@ -62,15 +62,6 @@ def __init__( server_args: ServerArgs, port_args: PortArgs, ): - # Init inter-process communication - context = zmq.Context(2) - self.recv_from_scheduler = get_zmq_socket( - context, zmq.PULL, port_args.detokenizer_ipc_name, True - ) - self.send_to_tokenizer = get_zmq_socket( - context, zmq.PUSH, port_args.tokenizer_ipc_name, False - ) - if server_args.skip_tokenizer_init: self.tokenizer = None else: diff --git a/python/sglang/srt/orchestration/std/detokenizer.py b/python/sglang/srt/orchestration/std/detokenizer.py index b6923c79c7e..39f9cd71a22 100644 --- a/python/sglang/srt/orchestration/std/detokenizer.py +++ b/python/sglang/srt/orchestration/std/detokenizer.py @@ -1,15 +1,10 @@ import logging -import signal -import psutil -import setproctitle import zmq from sglang.srt.managers.detokenizer_manager import DetokenizerManager -from sglang.srt.managers.io_struct import BatchEmbeddingOut, BatchTokenIDOut -from sglang.srt.server_args import PortArgs, ServerArgs -from sglang.srt.utils import configure_logger, get_zmq_socket -from sglang.utils import TypeBasedDispatcher, get_exception_traceback +from sglang.srt.server_args import PortArgs +from sglang.srt.utils import get_zmq_socket logger = 
logging.getLogger(__name__) @@ -18,4 +13,11 @@ class DetokenizerManagerCommunicator: def __init__(self, core: DetokenizerManager, port_args: PortArgs): self.core = core - TODO + # Init inter-process communication + context = zmq.Context(2) + self.recv_from_scheduler = get_zmq_socket( + context, zmq.PULL, port_args.detokenizer_ipc_name, True + ) + self.send_to_tokenizer = get_zmq_socket( + context, zmq.PUSH, port_args.tokenizer_ipc_name, False + ) From e5992fdcf9316b8e21d36b3de650fccf452f46c1 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sat, 25 Jan 2025 10:15:31 +0800 Subject: [PATCH 054/136] extract --- .../srt/managers/detokenizer_manager.py | 194 +++++++++--------- 1 file changed, 100 insertions(+), 94 deletions(-) diff --git a/python/sglang/srt/managers/detokenizer_manager.py b/python/sglang/srt/managers/detokenizer_manager.py index 972f9595b2c..13d83f302fe 100644 --- a/python/sglang/srt/managers/detokenizer_manager.py +++ b/python/sglang/srt/managers/detokenizer_manager.py @@ -32,7 +32,7 @@ ) from sglang.srt.server_args import PortArgs, ServerArgs from sglang.srt.utils import configure_logger, get_zmq_socket -from sglang.utils import find_printable_text, get_exception_traceback +from sglang.utils import find_printable_text, get_exception_traceback, TypeBasedDispatcher logger = logging.getLogger(__name__) @@ -83,6 +83,13 @@ def __init__( self.decode_status = LimitedCapacityDict(capacity=DETOKENIZER_MAX_STATES) + self._request_dispatcher = TypeBasedDispatcher( + [ + (BatchEmbeddingOut, self.handle_batch_embedding_out), + (BatchTokenIDOut, self.handle_batch_token_id_out), + ] + ) + def trim_matched_stop( self, output: Union[str, List[int]], finished_reason: Dict, no_stop_trim: bool ): @@ -111,107 +118,106 @@ def event_loop(self): while True: recv_obj = self.recv_from_scheduler.recv_pyobj() - - if isinstance(recv_obj, BatchEmbeddingOut): - # If it is embedding model, no detokenization is needed. - self.send_to_tokenizer.send_pyobj(recv_obj) - continue + output = self._request_dispatcher(recv_obj) + self.send_to_tokenizer.send_pyobj(output) + + def handle_batch_embedding_out(self, recv_obj: BatchEmbeddingOut): + # If it is embedding model, no detokenization is needed. 
+ return recv_obj + + def handle_batch_token_id_out(self, recv_obj: BatchTokenIDOut): + bs = len(recv_obj.rids) + + # Initialize decode status + read_ids, surr_ids = [], [] + for i in range(bs): + rid = recv_obj.rids[i] + vid = recv_obj.vids[i] + if rid not in self.decode_status or self.decode_status[rid].vid != vid: + s = DecodeStatus( + vid=vid, + decoded_text=recv_obj.decoded_texts[i], + decode_ids=recv_obj.decode_ids[i], + surr_offset=0, + read_offset=recv_obj.read_offsets[i], + ) + self.decode_status[rid] = s else: - assert isinstance(recv_obj, BatchTokenIDOut) - - bs = len(recv_obj.rids) - - # Initialize decode status - read_ids, surr_ids = [], [] - for i in range(bs): - rid = recv_obj.rids[i] - vid = recv_obj.vids[i] - if rid not in self.decode_status or self.decode_status[rid].vid != vid: - s = DecodeStatus( - vid=vid, - decoded_text=recv_obj.decoded_texts[i], - decode_ids=recv_obj.decode_ids[i], - surr_offset=0, - read_offset=recv_obj.read_offsets[i], - ) - self.decode_status[rid] = s - else: - s = self.decode_status[rid] - s.decode_ids = recv_obj.decode_ids[i] - - read_ids.append( - self.trim_matched_stop( - s.decode_ids[s.surr_offset :], - recv_obj.finished_reasons[i], - recv_obj.no_stop_trim[i], - ) + s = self.decode_status[rid] + s.decode_ids = recv_obj.decode_ids[i] + + read_ids.append( + self.trim_matched_stop( + s.decode_ids[s.surr_offset :], + recv_obj.finished_reasons[i], + recv_obj.no_stop_trim[i], ) - surr_ids.append(s.decode_ids[s.surr_offset : s.read_offset]) - - # TODO(lmzheng): handle skip_special_tokens/spaces_between_special_tokens per request - surr_texts = self.tokenizer.batch_decode( - surr_ids, - skip_special_tokens=recv_obj.skip_special_tokens[0], - spaces_between_special_tokens=recv_obj.spaces_between_special_tokens[0], - ) - read_texts = self.tokenizer.batch_decode( - read_ids, - skip_special_tokens=recv_obj.skip_special_tokens[0], - spaces_between_special_tokens=recv_obj.spaces_between_special_tokens[0], ) + surr_ids.append(s.decode_ids[s.surr_offset : s.read_offset]) - # Incremental decoding - output_strs = [] - for i in range(bs): - try: - s = self.decode_status[recv_obj.rids[i]] - except KeyError: - raise RuntimeError( - f"Decode status not found for request {recv_obj.rids[i]}. " - "It may be due to the request being evicted from the decode status due to memory pressure. " - "Please increase the maximum number of requests by setting " - "the SGLANG_DETOKENIZER_MAX_STATES environment variable to a bigger value than the default value. " - f"The current value is {DETOKENIZER_MAX_STATES}. 
" - "For more details, see: https://github.com/sgl-project/sglang/issues/2812" - ) - new_text = read_texts[i][len(surr_texts[i]) :] - if recv_obj.finished_reasons[i] is None: - # Streaming chunk: update the decode status - if len(new_text) > 0 and not new_text.endswith("�"): - s.decoded_text = s.decoded_text + new_text - s.surr_offset = s.read_offset - s.read_offset = len(s.decode_ids) - new_text = "" - else: - new_text = find_printable_text(new_text) - - output_strs.append( - self.trim_matched_stop( - s.decoded_text + new_text, - recv_obj.finished_reasons[i], - recv_obj.no_stop_trim[i], - ) - ) + # TODO(lmzheng): handle skip_special_tokens/spaces_between_special_tokens per request + surr_texts = self.tokenizer.batch_decode( + surr_ids, + skip_special_tokens=recv_obj.skip_special_tokens[0], + spaces_between_special_tokens=recv_obj.spaces_between_special_tokens[0], + ) + read_texts = self.tokenizer.batch_decode( + read_ids, + skip_special_tokens=recv_obj.skip_special_tokens[0], + spaces_between_special_tokens=recv_obj.spaces_between_special_tokens[0], + ) - self.send_to_tokenizer.send_pyobj( - BatchStrOut( - rids=recv_obj.rids, - finished_reasons=recv_obj.finished_reasons, - output_strs=output_strs, - prompt_tokens=recv_obj.prompt_tokens, - completion_tokens=recv_obj.completion_tokens, - cached_tokens=recv_obj.cached_tokens, - input_token_logprobs_val=recv_obj.input_token_logprobs_val, - input_token_logprobs_idx=recv_obj.input_token_logprobs_idx, - output_token_logprobs_val=recv_obj.output_token_logprobs_val, - output_token_logprobs_idx=recv_obj.output_token_logprobs_idx, - input_top_logprobs_val=recv_obj.input_top_logprobs_val, - input_top_logprobs_idx=recv_obj.input_top_logprobs_idx, - output_top_logprobs_val=recv_obj.output_top_logprobs_val, - output_top_logprobs_idx=recv_obj.output_top_logprobs_idx, + # Incremental decoding + output_strs = [] + for i in range(bs): + try: + s = self.decode_status[recv_obj.rids[i]] + except KeyError: + raise RuntimeError( + f"Decode status not found for request {recv_obj.rids[i]}. " + "It may be due to the request being evicted from the decode status due to memory pressure. " + "Please increase the maximum number of requests by setting " + "the SGLANG_DETOKENIZER_MAX_STATES environment variable to a bigger value than the default value. " + f"The current value is {DETOKENIZER_MAX_STATES}. 
" + "For more details, see: https://github.com/sgl-project/sglang/issues/2812" ) + new_text = read_texts[i][len(surr_texts[i]) :] + if recv_obj.finished_reasons[i] is None: + # Streaming chunk: update the decode status + if len(new_text) > 0 and not new_text.endswith("�"): + s.decoded_text = s.decoded_text + new_text + s.surr_offset = s.read_offset + s.read_offset = len(s.decode_ids) + new_text = "" + else: + new_text = find_printable_text(new_text) + + output_strs.append( + self.trim_matched_stop( + s.decoded_text + new_text, + recv_obj.finished_reasons[i], + recv_obj.no_stop_trim[i], + ) ) + return BatchStrOut( + rids=recv_obj.rids, + finished_reasons=recv_obj.finished_reasons, + output_strs=output_strs, + prompt_tokens=recv_obj.prompt_tokens, + completion_tokens=recv_obj.completion_tokens, + cached_tokens=recv_obj.cached_tokens, + input_token_logprobs_val=recv_obj.input_token_logprobs_val, + input_token_logprobs_idx=recv_obj.input_token_logprobs_idx, + output_token_logprobs_val=recv_obj.output_token_logprobs_val, + output_token_logprobs_idx=recv_obj.output_token_logprobs_idx, + input_top_logprobs_val=recv_obj.input_top_logprobs_val, + input_top_logprobs_idx=recv_obj.input_top_logprobs_idx, + output_top_logprobs_val=recv_obj.output_top_logprobs_val, + output_top_logprobs_idx=recv_obj.output_top_logprobs_idx, + ) + + class LimitedCapacityDict(OrderedDict): def __init__(self, capacity: int, *args, **kwargs): From 2dd8e8fcf85e843204c87b9eae574b6583c059b1 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sat, 25 Jan 2025 10:17:36 +0800 Subject: [PATCH 055/136] mv _request_dispatcher --- python/sglang/srt/managers/detokenizer_manager.py | 7 ------- python/sglang/srt/orchestration/std/detokenizer.py | 10 +++++++++- 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/python/sglang/srt/managers/detokenizer_manager.py b/python/sglang/srt/managers/detokenizer_manager.py index 18f91a56069..509bc224915 100644 --- a/python/sglang/srt/managers/detokenizer_manager.py +++ b/python/sglang/srt/managers/detokenizer_manager.py @@ -74,13 +74,6 @@ def __init__( self.decode_status = LimitedCapacityDict(capacity=DETOKENIZER_MAX_STATES) - self._request_dispatcher = TypeBasedDispatcher( - [ - (BatchEmbeddingOut, self.handle_batch_embedding_out), - (BatchTokenIDOut, self.handle_batch_token_id_out), - ] - ) - def trim_matched_stop( self, output: Union[str, List[int]], finished_reason: Dict, no_stop_trim: bool ): diff --git a/python/sglang/srt/orchestration/std/detokenizer.py b/python/sglang/srt/orchestration/std/detokenizer.py index 39f9cd71a22..4f756fb8457 100644 --- a/python/sglang/srt/orchestration/std/detokenizer.py +++ b/python/sglang/srt/orchestration/std/detokenizer.py @@ -1,10 +1,11 @@ import logging import zmq - from sglang.srt.managers.detokenizer_manager import DetokenizerManager +from sglang.srt.managers.io_struct import BatchTokenIDOut, BatchEmbeddingOut from sglang.srt.server_args import PortArgs from sglang.srt.utils import get_zmq_socket +from sglang.utils import TypeBasedDispatcher logger = logging.getLogger(__name__) @@ -21,3 +22,10 @@ def __init__(self, core: DetokenizerManager, port_args: PortArgs): self.send_to_tokenizer = get_zmq_socket( context, zmq.PUSH, port_args.tokenizer_ipc_name, False ) + + self._request_dispatcher = TypeBasedDispatcher( + [ + (BatchEmbeddingOut, self.core.handle_batch_embedding_out), + (BatchTokenIDOut, self.core.handle_batch_token_id_out), + ] + ) From a68c2e6af4246148b5bcc55802856d6bfc48dd62 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sat, 25 Jan 2025 10:18:17 
+0800 Subject: [PATCH 056/136] mv event loop --- python/sglang/srt/managers/detokenizer_manager.py | 8 -------- python/sglang/srt/orchestration/std/detokenizer.py | 9 +++++++-- 2 files changed, 7 insertions(+), 10 deletions(-) diff --git a/python/sglang/srt/managers/detokenizer_manager.py b/python/sglang/srt/managers/detokenizer_manager.py index 509bc224915..5a95562389d 100644 --- a/python/sglang/srt/managers/detokenizer_manager.py +++ b/python/sglang/srt/managers/detokenizer_manager.py @@ -97,14 +97,6 @@ def trim_matched_stop( return output[:-1] return output - def event_loop(self): - """The event loop that handles requests""" - - while True: - recv_obj = self.recv_from_scheduler.recv_pyobj() - output = self._request_dispatcher(recv_obj) - self.send_to_tokenizer.send_pyobj(output) - def handle_batch_embedding_out(self, recv_obj: BatchEmbeddingOut): # If it is embedding model, no detokenization is needed. return recv_obj diff --git a/python/sglang/srt/orchestration/std/detokenizer.py b/python/sglang/srt/orchestration/std/detokenizer.py index 4f756fb8457..e41ded84cf7 100644 --- a/python/sglang/srt/orchestration/std/detokenizer.py +++ b/python/sglang/srt/orchestration/std/detokenizer.py @@ -16,10 +16,10 @@ def __init__(self, core: DetokenizerManager, port_args: PortArgs): # Init inter-process communication context = zmq.Context(2) - self.recv_from_scheduler = get_zmq_socket( + self._recv_from_scheduler = get_zmq_socket( context, zmq.PULL, port_args.detokenizer_ipc_name, True ) - self.send_to_tokenizer = get_zmq_socket( + self._send_to_tokenizer = get_zmq_socket( context, zmq.PUSH, port_args.tokenizer_ipc_name, False ) @@ -29,3 +29,8 @@ def __init__(self, core: DetokenizerManager, port_args: PortArgs): (BatchTokenIDOut, self.core.handle_batch_token_id_out), ] ) + + def recv_and_process_input_requests(self): + recv_obj = self._recv_from_scheduler.recv_pyobj() + output_obj = self._request_dispatcher(recv_obj) + self._send_to_tokenizer.send_pyobj(output_obj) From f887f2a321c23c0969c9063027bee12fcc591e99 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sat, 25 Jan 2025 10:18:37 +0800 Subject: [PATCH 057/136] mv run process --- .../srt/managers/detokenizer_manager.py | 16 ------------- .../srt/orchestration/std/detokenizer.py | 23 +++++++++++++++++-- 2 files changed, 21 insertions(+), 18 deletions(-) diff --git a/python/sglang/srt/managers/detokenizer_manager.py b/python/sglang/srt/managers/detokenizer_manager.py index 5a95562389d..09536ca79f1 100644 --- a/python/sglang/srt/managers/detokenizer_manager.py +++ b/python/sglang/srt/managers/detokenizer_manager.py @@ -207,19 +207,3 @@ def __setitem__(self, key, value): # Set the new item super().__setitem__(key, value) - -def run_detokenizer_process( - server_args: ServerArgs, - port_args: PortArgs, -): - setproctitle.setproctitle("sglang::detokenizer") - configure_logger(server_args) - parent_process = psutil.Process().parent() - - try: - manager = DetokenizerManager(server_args, port_args) - manager.event_loop() - except Exception: - traceback = get_exception_traceback() - logger.error(f"DetokenizerManager hit an exception: {traceback}") - parent_process.send_signal(signal.SIGQUIT) diff --git a/python/sglang/srt/orchestration/std/detokenizer.py b/python/sglang/srt/orchestration/std/detokenizer.py index e41ded84cf7..74118635d16 100644 --- a/python/sglang/srt/orchestration/std/detokenizer.py +++ b/python/sglang/srt/orchestration/std/detokenizer.py @@ -1,10 +1,12 @@ import logging +import psutil +import setproctitle import zmq from 
sglang.srt.managers.detokenizer_manager import DetokenizerManager from sglang.srt.managers.io_struct import BatchTokenIDOut, BatchEmbeddingOut -from sglang.srt.server_args import PortArgs -from sglang.srt.utils import get_zmq_socket +from sglang.srt.server_args import PortArgs, ServerArgs +from sglang.srt.utils import get_zmq_socket, configure_logger from sglang.utils import TypeBasedDispatcher logger = logging.getLogger(__name__) @@ -34,3 +36,20 @@ def recv_and_process_input_requests(self): recv_obj = self._recv_from_scheduler.recv_pyobj() output_obj = self._request_dispatcher(recv_obj) self._send_to_tokenizer.send_pyobj(output_obj) + + +def run_detokenizer_process( + server_args: ServerArgs, + port_args: PortArgs, +): + setproctitle.setproctitle("sglang::detokenizer") + configure_logger(server_args) + parent_process = psutil.Process().parent() + + try: + manager = DetokenizerManager(server_args, port_args) + manager.event_loop() + except Exception: + traceback = get_exception_traceback() + logger.error(f"DetokenizerManager hit an exception: {traceback}") + parent_process.send_signal(signal.SIGQUIT) From 8d808718df1f3b285373bde3a56f6ddb0f19e3e4 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sat, 25 Jan 2025 10:18:50 +0800 Subject: [PATCH 058/136] update run process --- python/sglang/srt/orchestration/std/detokenizer.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/python/sglang/srt/orchestration/std/detokenizer.py b/python/sglang/srt/orchestration/std/detokenizer.py index 74118635d16..12c2e8712e7 100644 --- a/python/sglang/srt/orchestration/std/detokenizer.py +++ b/python/sglang/srt/orchestration/std/detokenizer.py @@ -1,4 +1,5 @@ import logging +import signal import psutil import setproctitle @@ -7,7 +8,7 @@ from sglang.srt.managers.io_struct import BatchTokenIDOut, BatchEmbeddingOut from sglang.srt.server_args import PortArgs, ServerArgs from sglang.srt.utils import get_zmq_socket, configure_logger -from sglang.utils import TypeBasedDispatcher +from sglang.utils import TypeBasedDispatcher, get_exception_traceback logger = logging.getLogger(__name__) @@ -47,8 +48,10 @@ def run_detokenizer_process( parent_process = psutil.Process().parent() try: - manager = DetokenizerManager(server_args, port_args) - manager.event_loop() + manager = DetokenizerManager(server_args) + communicator = DetokenizerManagerCommunicator(core=manager, port_args=port_args) + while True: + communicator.recv_and_process_input_requests() except Exception: traceback = get_exception_traceback() logger.error(f"DetokenizerManager hit an exception: {traceback}") From 844226d2bb06c5bb26d659e80dc429e447f8b25c Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sat, 25 Jan 2025 10:19:07 +0800 Subject: [PATCH 059/136] rm unused port_args --- python/sglang/srt/managers/detokenizer_manager.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/sglang/srt/managers/detokenizer_manager.py b/python/sglang/srt/managers/detokenizer_manager.py index 09536ca79f1..82d7ac73bfb 100644 --- a/python/sglang/srt/managers/detokenizer_manager.py +++ b/python/sglang/srt/managers/detokenizer_manager.py @@ -60,7 +60,6 @@ class DetokenizerManager: def __init__( self, server_args: ServerArgs, - port_args: PortArgs, ): if server_args.skip_tokenizer_init: self.tokenizer = None From 14dd20467d852d3367fa9a8bd7546520debeb924 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sat, 25 Jan 2025 10:22:19 +0800 Subject: [PATCH 060/136] fmt --- python/sglang/srt/managers/detokenizer_manager.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 
deletions(-) diff --git a/python/sglang/srt/managers/detokenizer_manager.py b/python/sglang/srt/managers/detokenizer_manager.py index 13d83f302fe..6599a8b797e 100644 --- a/python/sglang/srt/managers/detokenizer_manager.py +++ b/python/sglang/srt/managers/detokenizer_manager.py @@ -32,7 +32,11 @@ ) from sglang.srt.server_args import PortArgs, ServerArgs from sglang.srt.utils import configure_logger, get_zmq_socket -from sglang.utils import find_printable_text, get_exception_traceback, TypeBasedDispatcher +from sglang.utils import ( + TypeBasedDispatcher, + find_printable_text, + get_exception_traceback, +) logger = logging.getLogger(__name__) @@ -197,7 +201,7 @@ def handle_batch_token_id_out(self, recv_obj: BatchTokenIDOut): s.decoded_text + new_text, recv_obj.finished_reasons[i], recv_obj.no_stop_trim[i], - ) + ) ) return BatchStrOut( @@ -218,7 +222,6 @@ def handle_batch_token_id_out(self, recv_obj: BatchTokenIDOut): ) - class LimitedCapacityDict(OrderedDict): def __init__(self, capacity: int, *args, **kwargs): super().__init__(*args, **kwargs) From 7e2ca8313bcc1e694d795c285b64967156d9b08a Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sat, 25 Jan 2025 10:23:36 +0800 Subject: [PATCH 061/136] update import --- python/sglang/srt/entrypoints/engine.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/entrypoints/engine.py b/python/sglang/srt/entrypoints/engine.py index e3cb85d71a4..5eaa8a6d441 100644 --- a/python/sglang/srt/entrypoints/engine.py +++ b/python/sglang/srt/entrypoints/engine.py @@ -27,6 +27,8 @@ import threading from typing import AsyncIterator, Dict, Iterator, List, Optional, Tuple, Union +from sglang.srt.orchestration.std.detokenizer import run_detokenizer_process + # Fix a bug of Python threading setattr(threading, "_register_atexit", lambda *args, **kwargs: None) @@ -36,7 +38,6 @@ from sglang.srt.managers.data_parallel_controller import ( run_data_parallel_controller_process, ) -from sglang.srt.managers.detokenizer_manager import run_detokenizer_process from sglang.srt.managers.io_struct import ( EmbeddingReqInput, GenerateReqInput, From edb751cf49e9268af67ef5d0c152bd715b0cb194 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sat, 25 Jan 2025 10:24:07 +0800 Subject: [PATCH 062/136] mv run_scheduler_process --- python/sglang/srt/managers/scheduler.py | 46 ------ .../sglang/srt/orchestration/std/scheduler.py | 135 ++++++++++++++++++ 2 files changed, 135 insertions(+), 46 deletions(-) create mode 100644 python/sglang/srt/orchestration/std/scheduler.py diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 80d63be7b65..7e12042a670 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -1723,49 +1723,3 @@ def _import_static_state(model, static_params): self_named_buffers[name][...] 
= tensor -def run_scheduler_process( - server_args: ServerArgs, - port_args: PortArgs, - gpu_id: int, - tp_rank: int, - dp_rank: Optional[int], - pipe_writer, -): - setproctitle.setproctitle("sglang::scheduler") - faulthandler.enable() - - # [For Router] if env var "SGLANG_DP_RANK" exist, set dp_rank to the value of the env var - if dp_rank is None and "SGLANG_DP_RANK" in os.environ: - dp_rank = int(os.environ["SGLANG_DP_RANK"]) - - # Configue the logger - if dp_rank is None: - configure_logger(server_args, prefix=f" TP{tp_rank}") - else: - configure_logger(server_args, prefix=f" DP{dp_rank} TP{tp_rank}") - suppress_other_loggers() - - # Set cpu affinity to this gpu process - if get_bool_env_var("SGLANG_SET_CPU_AFFINITY"): - set_gpu_proc_affinity(server_args.tp_size, server_args.nnodes, gpu_id) - - parent_process = psutil.Process().parent() - - # Create a scheduler and run the event loop - try: - scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank, dp_rank) - pipe_writer.send( - { - "status": "ready", - "max_total_num_tokens": scheduler.max_total_num_tokens, - "max_req_input_len": scheduler.max_req_input_len, - } - ) - if scheduler.enable_overlap: - scheduler.event_loop_overlap() - else: - scheduler.event_loop_normal() - except Exception: - traceback = get_exception_traceback() - logger.error(f"Scheduler hit an exception: {traceback}") - parent_process.send_signal(signal.SIGQUIT) diff --git a/python/sglang/srt/orchestration/std/scheduler.py b/python/sglang/srt/orchestration/std/scheduler.py new file mode 100644 index 00000000000..df484b78a52 --- /dev/null +++ b/python/sglang/srt/orchestration/std/scheduler.py @@ -0,0 +1,135 @@ +import faulthandler +import logging +import os +import signal +import threading +import time +import warnings +from collections import deque +from concurrent import futures +from dataclasses import dataclass +from http import HTTPStatus +from types import SimpleNamespace +from typing import Dict, List, Optional, Tuple, Union + +import psutil +import setproctitle +import torch +import zmq + +from sglang.global_config import global_config +from sglang.srt.configs.model_config import ModelConfig +from sglang.srt.constrained.base_grammar_backend import create_grammar_backend +from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer +from sglang.srt.layers.dp_attention import compute_dp_attention_world_info +from sglang.srt.layers.logits_processor import LogitsProcessorOutput +from sglang.srt.managers.io_struct import ( + AbortReq, + BatchEmbeddingOut, + BatchTokenIDOut, + CloseSessionReqInput, + FlushCacheReq, + GetWeightsByNameReqInput, + GetWeightsByNameReqOutput, + InitWeightsUpdateGroupReqInput, + InitWeightsUpdateGroupReqOutput, + OpenSessionReqInput, + OpenSessionReqOutput, + ProfileReq, + ReleaseMemoryOccupationReqInput, + ReleaseMemoryOccupationReqOutput, + ResumeMemoryOccupationReqInput, + ResumeMemoryOccupationReqOutput, + TokenizedEmbeddingReqInput, + TokenizedGenerateReqInput, + UpdateWeightFromDiskReqInput, + UpdateWeightFromDiskReqOutput, + UpdateWeightsFromDistributedReqInput, + UpdateWeightsFromDistributedReqOutput, + UpdateWeightsFromTensorReqInput, + UpdateWeightsFromTensorReqOutput, +) +from sglang.srt.managers.schedule_batch import ( + FINISH_ABORT, + BaseFinishReason, + ImageInputs, + Req, + ScheduleBatch, + global_server_args_dict, +) +from sglang.srt.managers.schedule_policy import ( + AddReqResult, + PrefillAdder, + SchedulePolicy, +) +from sglang.srt.managers.session_controller import Session +from 
sglang.srt.managers.tp_worker import TpModelWorker +from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient +from sglang.srt.managers.utils import validate_input_length +from sglang.srt.mem_cache.chunk_cache import ChunkCache +from sglang.srt.mem_cache.radix_cache import RadixCache +from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats +from sglang.srt.model_executor.forward_batch_info import ForwardMode +from sglang.srt.server_args import PortArgs, ServerArgs +from sglang.srt.speculative.spec_info import SpeculativeAlgorithm +from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter +from sglang.srt.utils import ( + broadcast_pyobj, + configure_logger, + crash_on_warnings, + get_bool_env_var, + get_zmq_socket, + set_gpu_proc_affinity, + set_random_seed, + suppress_other_loggers, +) +from sglang.utils import TypeBasedDispatcher, get_exception_traceback + +logger = logging.getLogger(__name__) + +def run_scheduler_process( + server_args: ServerArgs, + port_args: PortArgs, + gpu_id: int, + tp_rank: int, + dp_rank: Optional[int], + pipe_writer, +): + setproctitle.setproctitle("sglang::scheduler") + faulthandler.enable() + + # [For Router] if env var "SGLANG_DP_RANK" exist, set dp_rank to the value of the env var + if dp_rank is None and "SGLANG_DP_RANK" in os.environ: + dp_rank = int(os.environ["SGLANG_DP_RANK"]) + + # Configue the logger + if dp_rank is None: + configure_logger(server_args, prefix=f" TP{tp_rank}") + else: + configure_logger(server_args, prefix=f" DP{dp_rank} TP{tp_rank}") + suppress_other_loggers() + + # Set cpu affinity to this gpu process + if get_bool_env_var("SGLANG_SET_CPU_AFFINITY"): + set_gpu_proc_affinity(server_args.tp_size, server_args.nnodes, gpu_id) + + parent_process = psutil.Process().parent() + + # Create a scheduler and run the event loop + try: + scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank, dp_rank) + pipe_writer.send( + { + "status": "ready", + "max_total_num_tokens": scheduler.max_total_num_tokens, + "max_req_input_len": scheduler.max_req_input_len, + } + ) + if scheduler.enable_overlap: + scheduler.event_loop_overlap() + else: + scheduler.event_loop_normal() + except Exception: + traceback = get_exception_traceback() + logger.error(f"Scheduler hit an exception: {traceback}") + parent_process.send_signal(signal.SIGQUIT) From 3939f961a55a415ca2fecb549062cd612a896963 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sat, 25 Jan 2025 10:24:29 +0800 Subject: [PATCH 063/136] empty class --- .../sglang/srt/orchestration/std/scheduler.py | 51 ++++++------------- 1 file changed, 15 insertions(+), 36 deletions(-) diff --git a/python/sglang/srt/orchestration/std/scheduler.py b/python/sglang/srt/orchestration/std/scheduler.py index df484b78a52..1800dfd30cd 100644 --- a/python/sglang/srt/orchestration/std/scheduler.py +++ b/python/sglang/srt/orchestration/std/scheduler.py @@ -23,47 +23,13 @@ from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer from sglang.srt.layers.dp_attention import compute_dp_attention_world_info from sglang.srt.layers.logits_processor import LogitsProcessorOutput -from sglang.srt.managers.io_struct import ( - AbortReq, - BatchEmbeddingOut, - BatchTokenIDOut, - CloseSessionReqInput, - FlushCacheReq, - GetWeightsByNameReqInput, - GetWeightsByNameReqOutput, - InitWeightsUpdateGroupReqInput, - InitWeightsUpdateGroupReqOutput, - OpenSessionReqInput, - OpenSessionReqOutput, - ProfileReq, - ReleaseMemoryOccupationReqInput, - 
ReleaseMemoryOccupationReqOutput, - ResumeMemoryOccupationReqInput, - ResumeMemoryOccupationReqOutput, - TokenizedEmbeddingReqInput, - TokenizedGenerateReqInput, - UpdateWeightFromDiskReqInput, - UpdateWeightFromDiskReqOutput, - UpdateWeightsFromDistributedReqInput, - UpdateWeightsFromDistributedReqOutput, - UpdateWeightsFromTensorReqInput, - UpdateWeightsFromTensorReqOutput, -) -from sglang.srt.managers.schedule_batch import ( - FINISH_ABORT, - BaseFinishReason, - ImageInputs, - Req, - ScheduleBatch, - global_server_args_dict, -) from sglang.srt.managers.schedule_policy import ( AddReqResult, PrefillAdder, SchedulePolicy, ) +from sglang.srt.managers.scheduler import Scheduler from sglang.srt.managers.session_controller import Session -from sglang.srt.managers.tp_worker import TpModelWorker from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient from sglang.srt.managers.utils import validate_input_length from sglang.srt.mem_cache.chunk_cache import ChunkCache @@ -76,7 +42,6 @@ from sglang.srt.utils import ( broadcast_pyobj, configure_logger, - crash_on_warnings, get_bool_env_var, get_zmq_socket, set_gpu_proc_affinity, @@ -87,6 +52,20 @@ logger = logging.getLogger(__name__) + +class SchedulerCommunicator: + def __init__( + self, + core: Scheduler, + server_args: ServerArgs, + port_args: PortArgs, + tp_rank: int, + ): + self.core = core + self.server_args = server_args + self.tp_rank = tp_rank + + def run_scheduler_process( server_args: ServerArgs, port_args: PortArgs, From 49908f18d19e41259c86df7f31aa4acacf95464e Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sat, 25 Jan 2025 10:25:41 +0800 Subject: [PATCH 064/136] mv scheduler.init --- python/sglang/srt/managers/scheduler.py | 54 ----------- .../sglang/srt/orchestration/std/scheduler.py | 89 +++++++++++++------ 2 files changed, 61 insertions(+), 82 deletions(-) diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 7e12042a670..a9f24565ae7 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -161,31 +161,6 @@ def __init__( ) ) - # Init inter-process communication - context = zmq.Context(2) - if self.attn_tp_rank == 0: - self.recv_from_tokenizer = get_zmq_socket( - context, zmq.PULL, port_args.scheduler_input_ipc_name, False - ) - self.send_to_tokenizer = get_zmq_socket( - context, zmq.PUSH, port_args.tokenizer_ipc_name, False - ) - - if server_args.skip_tokenizer_init: - # Directly send to the StdOrchestrator - self.send_to_detokenizer = get_zmq_socket( - context, zmq.PUSH, port_args.tokenizer_ipc_name, False - ) - else: - # Send to the DetokenizerManager - self.send_to_detokenizer = get_zmq_socket( - context, zmq.PUSH, port_args.detokenizer_ipc_name, False - ) - else: - self.recv_from_tokenizer = None - self.send_to_tokenizer = SimpleNamespace(send_pyobj=lambda x: None) - self.send_to_detokenizer = SimpleNamespace(send_pyobj=lambda x: None) - # Init tokenizer self.model_config = ModelConfig( server_args.model_path, @@ -408,35 +383,6 @@ def __init__( }, ) - # Init request dispatcher - self._request_dispatcher = TypeBasedDispatcher( - [ - (TokenizedGenerateReqInput, self.handle_generate_request), - (TokenizedEmbeddingReqInput, self.handle_embedding_request), - (FlushCacheReq, self.flush_cache_wrapped), - (AbortReq, self.abort_request), - (UpdateWeightFromDiskReqInput, self.update_weights_from_disk), - (InitWeightsUpdateGroupReqInput, self.init_weights_update_group), - ( - UpdateWeightsFromDistributedReqInput, - 
self.update_weights_from_distributed, - ), - (UpdateWeightsFromTensorReqInput, self.update_weights_from_tensor), - (GetWeightsByNameReqInput, self.get_weights_by_name), - (ProfileReq, self.profile), - (OpenSessionReqInput, self.open_session), - (CloseSessionReqInput, self.close_session), - ( - ReleaseMemoryOccupationReqInput, - lambda _: self.release_memory_occupation(), - ), - ( - ResumeMemoryOccupationReqInput, - lambda _: self.resume_memory_occupation(), - ), - ] - ) - def watchdog_thread(self): """A watch dog thread that will try to kill the server itself if one batch takes too long.""" self.watchdog_last_forward_ct = 0 diff --git a/python/sglang/srt/orchestration/std/scheduler.py b/python/sglang/srt/orchestration/std/scheduler.py index 1800dfd30cd..e2e2f960c23 100644 --- a/python/sglang/srt/orchestration/std/scheduler.py +++ b/python/sglang/srt/orchestration/std/scheduler.py @@ -2,50 +2,27 @@ import logging import os import signal -import threading import time -import warnings -from collections import deque -from concurrent import futures -from dataclasses import dataclass -from http import HTTPStatus from types import SimpleNamespace -from typing import Dict, List, Optional, Tuple, Union +from typing import Optional import psutil import setproctitle -import torch import zmq -from sglang.global_config import global_config -from sglang.srt.configs.model_config import ModelConfig -from sglang.srt.constrained.base_grammar_backend import create_grammar_backend -from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer -from sglang.srt.layers.dp_attention import compute_dp_attention_world_info -from sglang.srt.layers.logits_processor import LogitsProcessorOutput -from sglang.srt.managers.schedule_policy import ( - AddReqResult, - PrefillAdder, - SchedulePolicy, -) +from sglang.srt.managers.io_struct import TokenizedGenerateReqInput, TokenizedEmbeddingReqInput, FlushCacheReq, \ + AbortReq, UpdateWeightFromDiskReqInput, InitWeightsUpdateGroupReqInput, UpdateWeightsFromDistributedReqInput, \ + UpdateWeightsFromTensorReqInput, GetWeightsByNameReqInput, ProfileReq, OpenSessionReqInput, CloseSessionReqInput, \ + ReleaseMemoryOccupationReqInput, ResumeMemoryOccupationReqInput from sglang.srt.managers.scheduler import Scheduler -from sglang.srt.managers.session_controller import Session from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient -from sglang.srt.managers.utils import validate_input_length -from sglang.srt.mem_cache.chunk_cache import ChunkCache from sglang.srt.mem_cache.radix_cache import RadixCache -from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats -from sglang.srt.model_executor.forward_batch_info import ForwardMode from sglang.srt.server_args import PortArgs, ServerArgs -from sglang.srt.speculative.spec_info import SpeculativeAlgorithm -from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter from sglang.srt.utils import ( - broadcast_pyobj, configure_logger, get_bool_env_var, get_zmq_socket, set_gpu_proc_affinity, - set_random_seed, suppress_other_loggers, ) from sglang.utils import TypeBasedDispatcher, get_exception_traceback @@ -65,6 +42,62 @@ def __init__( self.server_args = server_args self.tp_rank = tp_rank + # Init inter-process communication + context = zmq.Context(2) + if self.attn_tp_rank == 0: + self.recv_from_tokenizer = get_zmq_socket( + context, zmq.PULL, port_args.scheduler_input_ipc_name, False + ) + self.send_to_tokenizer = get_zmq_socket( + context, zmq.PUSH, 
port_args.tokenizer_ipc_name, False + ) + + if server_args.skip_tokenizer_init: + # Directly send to the StdOrchestrator + self.send_to_detokenizer = get_zmq_socket( + context, zmq.PUSH, port_args.tokenizer_ipc_name, False + ) + else: + # Send to the DetokenizerManager + self.send_to_detokenizer = get_zmq_socket( + context, zmq.PUSH, port_args.detokenizer_ipc_name, False + ) + else: + self.recv_from_tokenizer = None + self.send_to_tokenizer = SimpleNamespace(send_pyobj=lambda x: None) + self.send_to_detokenizer = SimpleNamespace(send_pyobj=lambda x: None) + + # Init request dispatcher + self._request_dispatcher = TypeBasedDispatcher( + [ + (TokenizedGenerateReqInput, self.handle_generate_request), + (TokenizedEmbeddingReqInput, self.handle_embedding_request), + (FlushCacheReq, self.flush_cache_wrapped), + (AbortReq, self.abort_request), + (UpdateWeightFromDiskReqInput, self.update_weights_from_disk), + (InitWeightsUpdateGroupReqInput, self.init_weights_update_group), + ( + UpdateWeightsFromDistributedReqInput, + self.update_weights_from_distributed, + ), + (UpdateWeightsFromTensorReqInput, self.update_weights_from_tensor), + (GetWeightsByNameReqInput, self.get_weights_by_name), + (ProfileReq, self.profile), + (OpenSessionReqInput, self.open_session), + (CloseSessionReqInput, self.close_session), + ( + ReleaseMemoryOccupationReqInput, + lambda _: self.release_memory_occupation(), + ), + ( + ResumeMemoryOccupationReqInput, + lambda _: self.resume_memory_occupation(), + ), + ] + ) + + core.on_generation_output = self._handle_generation_output + def run_scheduler_process( server_args: ServerArgs, From 8ff66b762f1d6c6e5390bda99f178e43e099d833 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sat, 25 Jan 2025 10:26:11 +0800 Subject: [PATCH 065/136] update func call --- .../sglang/srt/orchestration/std/scheduler.py | 32 ++++++++----------- 1 file changed, 14 insertions(+), 18 deletions(-) diff --git a/python/sglang/srt/orchestration/std/scheduler.py b/python/sglang/srt/orchestration/std/scheduler.py index e2e2f960c23..1c7328a0831 100644 --- a/python/sglang/srt/orchestration/std/scheduler.py +++ b/python/sglang/srt/orchestration/std/scheduler.py @@ -2,21 +2,17 @@ import logging import os import signal -import time from types import SimpleNamespace from typing import Optional import psutil import setproctitle import zmq - from sglang.srt.managers.io_struct import TokenizedGenerateReqInput, TokenizedEmbeddingReqInput, FlushCacheReq, \ AbortReq, UpdateWeightFromDiskReqInput, InitWeightsUpdateGroupReqInput, UpdateWeightsFromDistributedReqInput, \ UpdateWeightsFromTensorReqInput, GetWeightsByNameReqInput, ProfileReq, OpenSessionReqInput, CloseSessionReqInput, \ ReleaseMemoryOccupationReqInput, ResumeMemoryOccupationReqInput from sglang.srt.managers.scheduler import Scheduler -from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient -from sglang.srt.mem_cache.radix_cache import RadixCache from sglang.srt.server_args import PortArgs, ServerArgs from sglang.srt.utils import ( configure_logger, @@ -70,28 +66,28 @@ def __init__( # Init request dispatcher self._request_dispatcher = TypeBasedDispatcher( [ - (TokenizedGenerateReqInput, self.handle_generate_request), - (TokenizedEmbeddingReqInput, self.handle_embedding_request), - (FlushCacheReq, self.flush_cache_wrapped), - (AbortReq, self.abort_request), - (UpdateWeightFromDiskReqInput, self.update_weights_from_disk), - (InitWeightsUpdateGroupReqInput, self.init_weights_update_group), + (TokenizedGenerateReqInput, 
self.core.handle_generate_request), + (TokenizedEmbeddingReqInput, self.core.handle_embedding_request), + (FlushCacheReq, self.core.flush_cache_wrapped), + (AbortReq, self.core.abort_request), + (UpdateWeightFromDiskReqInput, self.core.update_weights_from_disk), + (InitWeightsUpdateGroupReqInput, self.core.init_weights_update_group), ( UpdateWeightsFromDistributedReqInput, - self.update_weights_from_distributed, + self.core.update_weights_from_distributed, ), - (UpdateWeightsFromTensorReqInput, self.update_weights_from_tensor), - (GetWeightsByNameReqInput, self.get_weights_by_name), - (ProfileReq, self.profile), - (OpenSessionReqInput, self.open_session), - (CloseSessionReqInput, self.close_session), + (UpdateWeightsFromTensorReqInput, self.core.update_weights_from_tensor), + (GetWeightsByNameReqInput, self.core.get_weights_by_name), + (ProfileReq, self.core.profile), + (OpenSessionReqInput, self.core.open_session), + (CloseSessionReqInput, self.core.close_session), ( ReleaseMemoryOccupationReqInput, - lambda _: self.release_memory_occupation(), + lambda _: self.core.release_memory_occupation(), ), ( ResumeMemoryOccupationReqInput, - lambda _: self.resume_memory_occupation(), + lambda _: self.core.resume_memory_occupation(), ), ] ) From de3ee311c2b6393eb82c273ddb769f0d59e1fb02 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sat, 25 Jan 2025 10:26:56 +0800 Subject: [PATCH 066/136] field --- python/sglang/srt/managers/scheduler.py | 70 +++++++++++-------------- 1 file changed, 30 insertions(+), 40 deletions(-) diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index a9f24565ae7..5e61419b9ac 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -13,7 +13,6 @@ # ============================================================================== """A scheduler that manages a tensor parallel GPU worker.""" -import faulthandler import logging import os import signal @@ -24,11 +23,9 @@ from concurrent import futures from dataclasses import dataclass from http import HTTPStatus -from types import SimpleNamespace -from typing import Dict, List, Optional, Tuple, Union +from typing import Dict, List, Optional, Tuple, Union, Callable import psutil -import setproctitle import torch import zmq @@ -51,7 +48,6 @@ OpenSessionReqInput, OpenSessionReqOutput, ProfileReq, - ReleaseMemoryOccupationReqInput, ReleaseMemoryOccupationReqOutput, ResumeMemoryOccupationReqInput, ResumeMemoryOccupationReqOutput, @@ -90,15 +86,10 @@ from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter from sglang.srt.utils import ( broadcast_pyobj, - configure_logger, crash_on_warnings, get_bool_env_var, - get_zmq_socket, - set_gpu_proc_affinity, set_random_seed, - suppress_other_loggers, ) -from sglang.utils import TypeBasedDispatcher, get_exception_traceback logger = logging.getLogger(__name__) @@ -149,6 +140,7 @@ def __init__( if not self.spec_algorithm.is_none() else 1 ) + self.on_generation_output: Optional[Callable] = None # Distributed rank info self.dp_size = server_args.dp_size @@ -337,8 +329,8 @@ def __init__( 1.0, ) self.new_token_ratio_decay = ( - self.init_new_token_ratio - self.min_new_token_ratio - ) / global_config.default_new_token_ratio_decay_steps + self.init_new_token_ratio - self.min_new_token_ratio + ) / global_config.default_new_token_ratio_decay_steps self.new_token_ratio = self.init_new_token_ratio # Tells whether the current running batch is full so that we can skip @@ -681,9 +673,9 @@ def 
handle_embedding_request( def log_prefill_stats(self, adder, can_run_list, running_bs, has_being_chunked): self.tree_cache_metrics["total"] += ( - adder.log_input_tokens + adder.log_hit_tokens - ) / 10**9 - self.tree_cache_metrics["hit"] += (adder.log_hit_tokens) / 10**9 + adder.log_input_tokens + adder.log_hit_tokens + ) / 10 ** 9 + self.tree_cache_metrics["hit"] += (adder.log_hit_tokens) / 10 ** 9 tree_cache_hit_rate = ( self.tree_cache_metrics["hit"] / self.tree_cache_metrics["total"] ) @@ -862,10 +854,10 @@ def get_new_batch_prefill(self) -> Optional[ScheduleBatch]: if ( self.lora_paths and len( - lora_set - | set([req.lora_path for req in adder.can_run_list]) - | set([req.lora_path]) - ) + lora_set + | set([req.lora_path for req in adder.can_run_list]) + | set([req.lora_path]) + ) > self.max_loras_per_batch ): self.batch_is_full = True @@ -1085,7 +1077,7 @@ def process_batch_result_prefill( if self.is_mixed_chunk and self.enable_overlap and req.finished(): # Free the one delayed token for the mixed decode batch j = len(batch.out_cache_loc) - len(batch.reqs) + i - self.token_to_kv_pool.free(batch.out_cache_loc[j : j + 1]) + self.token_to_kv_pool.free(batch.out_cache_loc[j: j + 1]) continue if req.is_being_chunked <= 0: @@ -1172,7 +1164,7 @@ def process_batch_result_decode( if self.enable_overlap and req.finished(): # Free the one delayed token - self.token_to_kv_pool.free(batch.out_cache_loc[i : i + 1]) + self.token_to_kv_pool.free(batch.out_cache_loc[i: i + 1]) continue if batch.spec_algorithm.is_none(): @@ -1232,15 +1224,15 @@ def add_logprob_return_values( if req.input_token_logprobs_val is None: input_token_logprobs_val = output.input_token_logprobs[ - pt : pt + num_input_logprobs - 1 - req.last_update_decode_tokens - ] + pt: pt + num_input_logprobs - 1 - req.last_update_decode_tokens + ] input_token_logprobs_idx = req.fill_ids[ - len(req.fill_ids) - - num_input_logprobs - + 1 : len(req.fill_ids) - - req.last_update_decode_tokens - ] + len(req.fill_ids) + - num_input_logprobs + + 1: len(req.fill_ids) + - req.last_update_decode_tokens + ] # Clip the padded hash values from image tokens. # Otherwise, it will lead to detokenization errors. input_token_logprobs_idx = [ @@ -1261,18 +1253,18 @@ def add_logprob_return_values( # Some decode tokens are re-computed in an extend batch req.output_token_logprobs_val.extend( output.input_token_logprobs[ - pt - + num_input_logprobs - - 1 - - req.last_update_decode_tokens : pt - + num_input_logprobs - - 1 + pt + + num_input_logprobs + - 1 + - req.last_update_decode_tokens: pt + + num_input_logprobs + - 1 ], ) req.output_token_logprobs_idx.extend( req.fill_ids[ - len(req.fill_ids) - - req.last_update_decode_tokens : len(req.fill_ids) + len(req.fill_ids) + - req.last_update_decode_tokens: len(req.fill_ids) ] ) @@ -1286,10 +1278,10 @@ def add_logprob_return_values( if req.last_update_decode_tokens != 0: req.output_top_logprobs_val.extend( - output.input_top_logprobs_val[i][-req.last_update_decode_tokens :] + output.input_top_logprobs_val[i][-req.last_update_decode_tokens:] ) req.output_top_logprobs_idx.extend( - output.input_top_logprobs_idx[i][-req.last_update_decode_tokens :] + output.input_top_logprobs_idx[i][-req.last_update_decode_tokens:] ) req.output_top_logprobs_val.append(output.next_token_top_logprobs_val[i]) @@ -1667,5 +1659,3 @@ def _import_static_state(model, static_params): self_named_buffers = dict(model.named_buffers()) for name, tensor in static_params["buffers"]: self_named_buffers[name][...] 
= tensor - - From cddd94d46f2a8b921684f80835be1d37a4d3262f Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sat, 25 Jan 2025 10:28:34 +0800 Subject: [PATCH 067/136] make private --- .../sglang/srt/orchestration/std/scheduler.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/python/sglang/srt/orchestration/std/scheduler.py b/python/sglang/srt/orchestration/std/scheduler.py index 1c7328a0831..e3eaa4fa428 100644 --- a/python/sglang/srt/orchestration/std/scheduler.py +++ b/python/sglang/srt/orchestration/std/scheduler.py @@ -19,8 +19,7 @@ get_bool_env_var, get_zmq_socket, set_gpu_proc_affinity, - suppress_other_loggers, -) + suppress_other_loggers, ) from sglang.utils import TypeBasedDispatcher, get_exception_traceback logger = logging.getLogger(__name__) @@ -41,27 +40,27 @@ def __init__( # Init inter-process communication context = zmq.Context(2) if self.attn_tp_rank == 0: - self.recv_from_tokenizer = get_zmq_socket( + self._recv_from_tokenizer = get_zmq_socket( context, zmq.PULL, port_args.scheduler_input_ipc_name, False ) - self.send_to_tokenizer = get_zmq_socket( + self._send_to_tokenizer = get_zmq_socket( context, zmq.PUSH, port_args.tokenizer_ipc_name, False ) if server_args.skip_tokenizer_init: # Directly send to the StdOrchestrator - self.send_to_detokenizer = get_zmq_socket( + self._send_to_detokenizer = get_zmq_socket( context, zmq.PUSH, port_args.tokenizer_ipc_name, False ) else: # Send to the DetokenizerManager - self.send_to_detokenizer = get_zmq_socket( + self._send_to_detokenizer = get_zmq_socket( context, zmq.PUSH, port_args.detokenizer_ipc_name, False ) else: - self.recv_from_tokenizer = None - self.send_to_tokenizer = SimpleNamespace(send_pyobj=lambda x: None) - self.send_to_detokenizer = SimpleNamespace(send_pyobj=lambda x: None) + self._recv_from_tokenizer = None + self._send_to_tokenizer = SimpleNamespace(send_pyobj=lambda x: None) + self._send_to_detokenizer = SimpleNamespace(send_pyobj=lambda x: None) # Init request dispatcher self._request_dispatcher = TypeBasedDispatcher( From 6bc0090831f16086ccbec03dd6f4a1cddd613a9d Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sat, 25 Jan 2025 10:28:47 +0800 Subject: [PATCH 068/136] mv _process_input_requests --- python/sglang/srt/managers/scheduler.py | 8 -------- python/sglang/srt/orchestration/std/scheduler.py | 8 +++++++- 2 files changed, 7 insertions(+), 9 deletions(-) diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 5e61419b9ac..d917e59d272 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -28,7 +28,6 @@ import psutil import torch import zmq - from sglang.global_config import global_config from sglang.srt.configs.model_config import ModelConfig from sglang.srt.constrained.base_grammar_backend import create_grammar_backend @@ -49,7 +48,6 @@ OpenSessionReqOutput, ProfileReq, ReleaseMemoryOccupationReqOutput, - ResumeMemoryOccupationReqInput, ResumeMemoryOccupationReqOutput, TokenizedEmbeddingReqInput, TokenizedGenerateReqInput, @@ -506,12 +504,6 @@ def recv_requests(self) -> List[Req]: recv_reqs = broadcast_pyobj(recv_reqs, self.tp_rank, self.tp_cpu_group) return recv_reqs - def process_input_requests(self, recv_reqs: List): - for recv_req in recv_reqs: - output = self._request_dispatcher(recv_req) - if output is not None: - self.send_to_tokenizer.send_pyobj(output) - def handle_generate_request( self, recv_req: TokenizedGenerateReqInput, diff --git 
a/python/sglang/srt/orchestration/std/scheduler.py b/python/sglang/srt/orchestration/std/scheduler.py index e3eaa4fa428..ab55508de23 100644 --- a/python/sglang/srt/orchestration/std/scheduler.py +++ b/python/sglang/srt/orchestration/std/scheduler.py @@ -3,7 +3,7 @@ import os import signal from types import SimpleNamespace -from typing import Optional +from typing import Optional, List import psutil import setproctitle @@ -93,6 +93,12 @@ def __init__( core.on_generation_output = self._handle_generation_output + def _process_input_requests(self, recv_reqs: List): + for recv_req in recv_reqs: + output = self._request_dispatcher(recv_req) + if output is not None: + self._send_to_tokenizer.send_pyobj(output) + def run_scheduler_process( server_args: ServerArgs, From c9afaefc7a9c9c5d775bc6d576fd5519ff661c23 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sat, 25 Jan 2025 10:29:04 +0800 Subject: [PATCH 069/136] call on_generation_output --- python/sglang/srt/managers/scheduler.py | 4 ++-- python/sglang/srt/orchestration/std/scheduler.py | 3 +++ 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index d917e59d272..c7e509f016d 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -1366,7 +1366,7 @@ def stream_output( # Send to detokenizer if rids: - self.send_to_detokenizer.send_pyobj( + self.on_generation_output( BatchTokenIDOut( rids, finished_reasons, @@ -1400,7 +1400,7 @@ def stream_output( finished_reasons.append(req.finished_reason.to_json()) embeddings.append(req.embedding) prompt_tokens.append(len(req.origin_input_ids)) - self.send_to_detokenizer.send_pyobj( + self.on_generation_output( BatchEmbeddingOut(rids, finished_reasons, embeddings, prompt_tokens) ) diff --git a/python/sglang/srt/orchestration/std/scheduler.py b/python/sglang/srt/orchestration/std/scheduler.py index ab55508de23..ff919c56961 100644 --- a/python/sglang/srt/orchestration/std/scheduler.py +++ b/python/sglang/srt/orchestration/std/scheduler.py @@ -99,6 +99,9 @@ def _process_input_requests(self, recv_reqs: List): if output is not None: self._send_to_tokenizer.send_pyobj(output) + def _handle_generation_output(self, obj): + self._send_to_detokenizer.send_pyobj(obj) + def run_scheduler_process( server_args: ServerArgs, From 4cd5ec168cc9e53901a689d2133682dc72a00ad0 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sat, 25 Jan 2025 10:29:39 +0800 Subject: [PATCH 070/136] mv recv_requests --- python/sglang/srt/managers/scheduler.py | 53 ----------------- .../sglang/srt/orchestration/std/scheduler.py | 57 ++++++++++++++++++- 2 files changed, 56 insertions(+), 54 deletions(-) diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index c7e509f016d..883daf60c1c 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -27,7 +27,6 @@ import psutil import torch -import zmq from sglang.global_config import global_config from sglang.srt.configs.model_config import ModelConfig from sglang.srt.constrained.base_grammar_backend import create_grammar_backend @@ -83,7 +82,6 @@ from sglang.srt.speculative.spec_info import SpeculativeAlgorithm from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter from sglang.srt.utils import ( - broadcast_pyobj, crash_on_warnings, get_bool_env_var, set_random_seed, @@ -453,57 +451,6 @@ def event_loop_overlap(self): self.last_batch = batch - def recv_requests(self) -> 
List[Req]: - """Receive results at tp_rank = 0 and broadcast it to all other TP ranks.""" - if self.attn_tp_rank == 0: - recv_reqs = [] - - while True: - try: - recv_req = self.recv_from_tokenizer.recv_pyobj(zmq.NOBLOCK) - except zmq.ZMQError: - break - recv_reqs.append(recv_req) - else: - recv_reqs = None - - if self.server_args.enable_dp_attention: - if self.attn_tp_rank == 0: - work_reqs = [ - req - for req in recv_reqs - if isinstance( - req, (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput) - ) - ] - control_reqs = [ - req - for req in recv_reqs - if not isinstance( - req, (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput) - ) - ] - else: - work_reqs = None - control_reqs = None - - if self.attn_tp_size != 1: - attn_tp_rank_0 = self.dp_rank * self.attn_tp_size - work_reqs = broadcast_pyobj( - work_reqs, - self.attn_tp_rank, - self.attn_tp_cpu_group, - src=attn_tp_rank_0, - ) - if self.tp_size != 1: - control_reqs = broadcast_pyobj( - control_reqs, self.tp_rank, self.tp_cpu_group - ) - recv_reqs = work_reqs + control_reqs - elif self.tp_size != 1: - recv_reqs = broadcast_pyobj(recv_reqs, self.tp_rank, self.tp_cpu_group) - return recv_reqs - def handle_generate_request( self, recv_req: TokenizedGenerateReqInput, diff --git a/python/sglang/srt/orchestration/std/scheduler.py b/python/sglang/srt/orchestration/std/scheduler.py index ff919c56961..6dfaf9deb59 100644 --- a/python/sglang/srt/orchestration/std/scheduler.py +++ b/python/sglang/srt/orchestration/std/scheduler.py @@ -12,6 +12,7 @@ AbortReq, UpdateWeightFromDiskReqInput, InitWeightsUpdateGroupReqInput, UpdateWeightsFromDistributedReqInput, \ UpdateWeightsFromTensorReqInput, GetWeightsByNameReqInput, ProfileReq, OpenSessionReqInput, CloseSessionReqInput, \ ReleaseMemoryOccupationReqInput, ResumeMemoryOccupationReqInput +from sglang.srt.managers.schedule_batch import Req from sglang.srt.managers.scheduler import Scheduler from sglang.srt.server_args import PortArgs, ServerArgs from sglang.srt.utils import ( @@ -19,7 +20,7 @@ get_bool_env_var, get_zmq_socket, set_gpu_proc_affinity, - suppress_other_loggers, ) + suppress_other_loggers, broadcast_pyobj, ) from sglang.utils import TypeBasedDispatcher, get_exception_traceback logger = logging.getLogger(__name__) @@ -93,6 +94,60 @@ def __init__( core.on_generation_output = self._handle_generation_output + def recv_and_process_input_requests(self): + self._process_input_requests(self._recv_requests()) + + def _recv_requests(self) -> List[Req]: + """Receive results at tp_rank = 0 and broadcast it to all other TP ranks.""" + if self.attn_tp_rank == 0: + recv_reqs = [] + + while True: + try: + recv_req = self._recv_from_tokenizer.recv_pyobj(zmq.NOBLOCK) + except zmq.ZMQError: + break + recv_reqs.append(recv_req) + else: + recv_reqs = None + + if self.server_args.enable_dp_attention: + if self.attn_tp_rank == 0: + work_reqs = [ + req + for req in recv_reqs + if isinstance( + req, (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput) + ) + ] + control_reqs = [ + req + for req in recv_reqs + if not isinstance( + req, (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput) + ) + ] + else: + work_reqs = None + control_reqs = None + + if self.attn_tp_size != 1: + attn_tp_rank_0 = self.dp_rank * self.attn_tp_size + work_reqs = broadcast_pyobj( + work_reqs, + self.attn_tp_rank, + self.attn_tp_cpu_group, + src=attn_tp_rank_0, + ) + if self.tp_size != 1: + control_reqs = broadcast_pyobj( + control_reqs, self.tp_rank, self.tp_cpu_group + ) + recv_reqs = work_reqs + control_reqs + elif 
self.tp_size != 1: + recv_reqs = broadcast_pyobj(recv_reqs, self.tp_rank, self.tp_cpu_group) + return recv_reqs + def _process_input_requests(self, recv_reqs: List): for recv_req in recv_reqs: output = self._request_dispatcher(recv_req) From 479158880b202a2fd96688a5d18ea3e914d50336 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sat, 25 Jan 2025 10:31:05 +0800 Subject: [PATCH 071/136] update run process --- python/sglang/srt/orchestration/std/scheduler.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/python/sglang/srt/orchestration/std/scheduler.py b/python/sglang/srt/orchestration/std/scheduler.py index 6dfaf9deb59..15752f3e709 100644 --- a/python/sglang/srt/orchestration/std/scheduler.py +++ b/python/sglang/srt/orchestration/std/scheduler.py @@ -189,6 +189,13 @@ def run_scheduler_process( # Create a scheduler and run the event loop try: scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank, dp_rank) + communicator = SchedulerCommunicator( + core=scheduler, + server_args=server_args, + port_args=port_args, + tp_rank=tp_rank, + ) + pipe_writer.send( { "status": "ready", @@ -196,10 +203,10 @@ def run_scheduler_process( "max_req_input_len": scheduler.max_req_input_len, } ) - if scheduler.enable_overlap: - scheduler.event_loop_overlap() - else: - scheduler.event_loop_normal() + + while True: + communicator.recv_and_process_input_requests() + scheduler.process_batch() except Exception: traceback = get_exception_traceback() logger.error(f"Scheduler hit an exception: {traceback}") From b54bf6c1d989695691d26d348115816cfd45b89a Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sat, 25 Jan 2025 10:32:29 +0800 Subject: [PATCH 072/136] extract process_batch --- python/sglang/srt/managers/scheduler.py | 97 ++++++++++++------------- 1 file changed, 46 insertions(+), 51 deletions(-) diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 883daf60c1c..ca28dd2a2c3 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -285,6 +285,7 @@ def __init__( self.last_decode_stats_tic = time.time() self.stream_interval = server_args.stream_interval self.current_stream = torch.get_device_module(self.device).current_stream() + self.overlap_result_queue = deque() if self.device == "cpu": self.current_stream.synchronize = lambda: None # No-op for CPU @@ -391,65 +392,59 @@ def watchdog_thread(self): time.sleep(5) self.parent_process.send_signal(signal.SIGQUIT) - @torch.no_grad() - def event_loop_normal(self): - """A normal scheduler loop.""" - while True: - recv_reqs = self.recv_requests() - self.process_input_requests(recv_reqs) + def process_batch(self): + if self.enable_overlap: + return self._process_batch_overlap() + else: + return self._process_batch_normal() - batch = self.get_next_batch_to_run() - self.cur_batch = batch + @torch.no_grad() + def _process_batch_normal(self): + batch = self.get_next_batch_to_run() + self.cur_batch = batch - if batch: - result = self.run_batch(batch) - self.process_batch_result(batch, result) - else: - # When the server is idle, so self-check and re-init some states - self.check_memory() - self.new_token_ratio = self.init_new_token_ratio + if batch: + result = self.run_batch(batch) + self.process_batch_result(batch, result) + else: + # When the server is idle, so self-check and re-init some states + self.check_memory() + self.new_token_ratio = self.init_new_token_ratio - self.last_batch = batch + self.last_batch = batch @torch.no_grad() def 
event_loop_overlap(self): - """A scheduler loop that overlaps the CPU processing and GPU computation.""" - result_queue = deque() - - while True: - recv_reqs = self.recv_requests() - self.process_input_requests(recv_reqs) - - batch = self.get_next_batch_to_run() - self.cur_batch = batch - - if batch: - result = self.run_batch(batch) - result_queue.append((batch.copy(), result)) - - if self.last_batch is None: - # Create a dummy first batch to start the pipeline for overlap schedule. - # It is now used for triggering the sampling_info_done event. - tmp_batch = ScheduleBatch( - reqs=None, - forward_mode=ForwardMode.DUMMY_FIRST, - next_batch_sampling_info=self.tp_worker.cur_sampling_info, - ) - self.process_batch_result(tmp_batch, None) - - if self.last_batch: - # Process the results of the last batch - tmp_batch, tmp_result = result_queue.popleft() - tmp_batch.next_batch_sampling_info = ( - self.tp_worker.cur_sampling_info if batch else None + batch = self.get_next_batch_to_run() + self.cur_batch = batch + + if batch: + result = self.run_batch(batch) + self.overlap_result_queue.append((batch.copy(), result)) + + if self.last_batch is None: + # Create a dummy first batch to start the pipeline for overlap schedule. + # It is now used for triggering the sampling_info_done event. + tmp_batch = ScheduleBatch( + reqs=None, + forward_mode=ForwardMode.DUMMY_FIRST, + next_batch_sampling_info=self.tp_worker.cur_sampling_info, ) - self.process_batch_result(tmp_batch, tmp_result) - elif batch is None: - # When the server is idle, so self-check and re-init some states - self.check_memory() - self.new_token_ratio = self.init_new_token_ratio + self.process_batch_result(tmp_batch, None) + + if self.last_batch: + # Process the results of the last batch + tmp_batch, tmp_result = self.overlap_result_queue.popleft() + tmp_batch.next_batch_sampling_info = ( + self.tp_worker.cur_sampling_info if batch else None + ) + self.process_batch_result(tmp_batch, tmp_result) + elif batch is None: + # When the server is idle, so self-check and re-init some states + self.check_memory() + self.new_token_ratio = self.init_new_token_ratio - self.last_batch = batch + self.last_batch = batch def handle_generate_request( self, From f252b0fec925d4d1ae82ac51e9dfe79b3f208d37 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sat, 25 Jan 2025 10:33:23 +0800 Subject: [PATCH 073/136] fix name --- python/sglang/srt/managers/scheduler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index ca28dd2a2c3..03b33151300 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -414,7 +414,7 @@ def _process_batch_normal(self): self.last_batch = batch @torch.no_grad() - def event_loop_overlap(self): + def _process_batch_overlap(self): batch = self.get_next_batch_to_run() self.cur_batch = batch From 231898120f1c97e1294b09bb8eb628f9bf2685a5 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sat, 25 Jan 2025 10:35:14 +0800 Subject: [PATCH 074/136] fix field name --- .../sglang/srt/orchestration/std/scheduler.py | 25 ++++++++----------- 1 file changed, 11 insertions(+), 14 deletions(-) diff --git a/python/sglang/srt/orchestration/std/scheduler.py b/python/sglang/srt/orchestration/std/scheduler.py index 15752f3e709..4fddd7dc44b 100644 --- a/python/sglang/srt/orchestration/std/scheduler.py +++ b/python/sglang/srt/orchestration/std/scheduler.py @@ -32,15 +32,13 @@ def __init__( core: Scheduler, server_args: ServerArgs, 
port_args: PortArgs, - tp_rank: int, ): self.core = core self.server_args = server_args - self.tp_rank = tp_rank # Init inter-process communication context = zmq.Context(2) - if self.attn_tp_rank == 0: + if self.core.attn_tp_rank == 0: self._recv_from_tokenizer = get_zmq_socket( context, zmq.PULL, port_args.scheduler_input_ipc_name, False ) @@ -99,7 +97,7 @@ def recv_and_process_input_requests(self): def _recv_requests(self) -> List[Req]: """Receive results at tp_rank = 0 and broadcast it to all other TP ranks.""" - if self.attn_tp_rank == 0: + if self.core.attn_tp_rank == 0: recv_reqs = [] while True: @@ -112,7 +110,7 @@ def _recv_requests(self) -> List[Req]: recv_reqs = None if self.server_args.enable_dp_attention: - if self.attn_tp_rank == 0: + if self.core.attn_tp_rank == 0: work_reqs = [ req for req in recv_reqs @@ -131,21 +129,21 @@ def _recv_requests(self) -> List[Req]: work_reqs = None control_reqs = None - if self.attn_tp_size != 1: - attn_tp_rank_0 = self.dp_rank * self.attn_tp_size + if self.core.attn_tp_size != 1: + attn_tp_rank_0 = self.core.dp_rank * self.core.attn_tp_size work_reqs = broadcast_pyobj( work_reqs, - self.attn_tp_rank, - self.attn_tp_cpu_group, + self.core.attn_tp_rank, + self.core.attn_tp_cpu_group, src=attn_tp_rank_0, ) - if self.tp_size != 1: + if self.core.tp_size != 1: control_reqs = broadcast_pyobj( - control_reqs, self.tp_rank, self.tp_cpu_group + control_reqs, self.core.tp_rank, self.core.tp_cpu_group ) recv_reqs = work_reqs + control_reqs - elif self.tp_size != 1: - recv_reqs = broadcast_pyobj(recv_reqs, self.tp_rank, self.tp_cpu_group) + elif self.core.tp_size != 1: + recv_reqs = broadcast_pyobj(recv_reqs, self.tp_rank, self.core.tp_cpu_group) return recv_reqs def _process_input_requests(self, recv_reqs: List): @@ -193,7 +191,6 @@ def run_scheduler_process( core=scheduler, server_args=server_args, port_args=port_args, - tp_rank=tp_rank, ) pipe_writer.send( From 6f4f2c039df927fc3f4896ffea41e1325f72b6df Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sat, 25 Jan 2025 10:35:46 +0800 Subject: [PATCH 075/136] simp args --- python/sglang/srt/managers/scheduler.py | 8 ++++---- python/sglang/srt/orchestration/std/scheduler.py | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 03b33151300..f0679cc0164 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -78,7 +78,7 @@ from sglang.srt.mem_cache.radix_cache import RadixCache from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats from sglang.srt.model_executor.forward_batch_info import ForwardMode -from sglang.srt.server_args import PortArgs, ServerArgs +from sglang.srt.server_args import ServerArgs from sglang.srt.speculative.spec_info import SpeculativeAlgorithm from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter from sglang.srt.utils import ( @@ -112,7 +112,7 @@ class Scheduler: def __init__( self, server_args: ServerArgs, - port_args: PortArgs, + nccl_port: int, gpu_id: int, tp_rank: int, dp_rank: Optional[int], @@ -204,7 +204,7 @@ def __init__( gpu_id=gpu_id, tp_rank=tp_rank, dp_rank=dp_rank, - nccl_port=port_args.nccl_port, + nccl_port=nccl_port, ) # Launch a worker for speculative decoding if needed @@ -215,7 +215,7 @@ def __init__( gpu_id=gpu_id, tp_rank=tp_rank, server_args=server_args, - nccl_port=port_args.nccl_port, + nccl_port=nccl_port, target_worker=self.tp_worker, dp_rank=dp_rank, ) diff --git 
a/python/sglang/srt/orchestration/std/scheduler.py b/python/sglang/srt/orchestration/std/scheduler.py index 4fddd7dc44b..521000eaa3b 100644 --- a/python/sglang/srt/orchestration/std/scheduler.py +++ b/python/sglang/srt/orchestration/std/scheduler.py @@ -186,7 +186,7 @@ def run_scheduler_process( # Create a scheduler and run the event loop try: - scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank, dp_rank) + scheduler = Scheduler(server_args, port_args.nccl_port, gpu_id, tp_rank, dp_rank) communicator = SchedulerCommunicator( core=scheduler, server_args=server_args, From 4de6e76550c4f13a183a676b95ce0037e0deac6d Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sat, 25 Jan 2025 10:35:59 +0800 Subject: [PATCH 076/136] fmt --- .../srt/managers/detokenizer_manager.py | 1 - python/sglang/srt/managers/scheduler.py | 65 +++++++++---------- 2 files changed, 31 insertions(+), 35 deletions(-) diff --git a/python/sglang/srt/managers/detokenizer_manager.py b/python/sglang/srt/managers/detokenizer_manager.py index e9d88d8bbe8..9a8b3659bdb 100644 --- a/python/sglang/srt/managers/detokenizer_manager.py +++ b/python/sglang/srt/managers/detokenizer_manager.py @@ -208,4 +208,3 @@ def __setitem__(self, key, value): self.popitem(last=False) # Set the new item super().__setitem__(key, value) - diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index f0679cc0164..c14f58c448f 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -23,10 +23,11 @@ from concurrent import futures from dataclasses import dataclass from http import HTTPStatus -from typing import Dict, List, Optional, Tuple, Union, Callable +from typing import Callable, Dict, List, Optional, Tuple, Union import psutil import torch + from sglang.global_config import global_config from sglang.srt.configs.model_config import ModelConfig from sglang.srt.constrained.base_grammar_backend import create_grammar_backend @@ -81,11 +82,7 @@ from sglang.srt.server_args import ServerArgs from sglang.srt.speculative.spec_info import SpeculativeAlgorithm from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter -from sglang.srt.utils import ( - crash_on_warnings, - get_bool_env_var, - set_random_seed, -) +from sglang.srt.utils import crash_on_warnings, get_bool_env_var, set_random_seed logger = logging.getLogger(__name__) @@ -326,8 +323,8 @@ def __init__( 1.0, ) self.new_token_ratio_decay = ( - self.init_new_token_ratio - self.min_new_token_ratio - ) / global_config.default_new_token_ratio_decay_steps + self.init_new_token_ratio - self.min_new_token_ratio + ) / global_config.default_new_token_ratio_decay_steps self.new_token_ratio = self.init_new_token_ratio # Tells whether the current running batch is full so that we can skip @@ -607,9 +604,9 @@ def handle_embedding_request( def log_prefill_stats(self, adder, can_run_list, running_bs, has_being_chunked): self.tree_cache_metrics["total"] += ( - adder.log_input_tokens + adder.log_hit_tokens - ) / 10 ** 9 - self.tree_cache_metrics["hit"] += (adder.log_hit_tokens) / 10 ** 9 + adder.log_input_tokens + adder.log_hit_tokens + ) / 10**9 + self.tree_cache_metrics["hit"] += (adder.log_hit_tokens) / 10**9 tree_cache_hit_rate = ( self.tree_cache_metrics["hit"] / self.tree_cache_metrics["total"] ) @@ -788,10 +785,10 @@ def get_new_batch_prefill(self) -> Optional[ScheduleBatch]: if ( self.lora_paths and len( - lora_set - | set([req.lora_path for req in adder.can_run_list]) - | set([req.lora_path]) - ) + lora_set + | 
set([req.lora_path for req in adder.can_run_list]) + | set([req.lora_path]) + ) > self.max_loras_per_batch ): self.batch_is_full = True @@ -1011,7 +1008,7 @@ def process_batch_result_prefill( if self.is_mixed_chunk and self.enable_overlap and req.finished(): # Free the one delayed token for the mixed decode batch j = len(batch.out_cache_loc) - len(batch.reqs) + i - self.token_to_kv_pool.free(batch.out_cache_loc[j: j + 1]) + self.token_to_kv_pool.free(batch.out_cache_loc[j : j + 1]) continue if req.is_being_chunked <= 0: @@ -1098,7 +1095,7 @@ def process_batch_result_decode( if self.enable_overlap and req.finished(): # Free the one delayed token - self.token_to_kv_pool.free(batch.out_cache_loc[i: i + 1]) + self.token_to_kv_pool.free(batch.out_cache_loc[i : i + 1]) continue if batch.spec_algorithm.is_none(): @@ -1158,15 +1155,15 @@ def add_logprob_return_values( if req.input_token_logprobs_val is None: input_token_logprobs_val = output.input_token_logprobs[ - pt: pt + num_input_logprobs - 1 - req.last_update_decode_tokens - ] + pt : pt + num_input_logprobs - 1 - req.last_update_decode_tokens + ] input_token_logprobs_idx = req.fill_ids[ - len(req.fill_ids) - - num_input_logprobs - + 1: len(req.fill_ids) - - req.last_update_decode_tokens - ] + len(req.fill_ids) + - num_input_logprobs + + 1 : len(req.fill_ids) + - req.last_update_decode_tokens + ] # Clip the padded hash values from image tokens. # Otherwise, it will lead to detokenization errors. input_token_logprobs_idx = [ @@ -1187,18 +1184,18 @@ def add_logprob_return_values( # Some decode tokens are re-computed in an extend batch req.output_token_logprobs_val.extend( output.input_token_logprobs[ - pt - + num_input_logprobs - - 1 - - req.last_update_decode_tokens: pt - + num_input_logprobs - - 1 + pt + + num_input_logprobs + - 1 + - req.last_update_decode_tokens : pt + + num_input_logprobs + - 1 ], ) req.output_token_logprobs_idx.extend( req.fill_ids[ - len(req.fill_ids) - - req.last_update_decode_tokens: len(req.fill_ids) + len(req.fill_ids) + - req.last_update_decode_tokens : len(req.fill_ids) ] ) @@ -1212,10 +1209,10 @@ def add_logprob_return_values( if req.last_update_decode_tokens != 0: req.output_top_logprobs_val.extend( - output.input_top_logprobs_val[i][-req.last_update_decode_tokens:] + output.input_top_logprobs_val[i][-req.last_update_decode_tokens :] ) req.output_top_logprobs_idx.extend( - output.input_top_logprobs_idx[i][-req.last_update_decode_tokens:] + output.input_top_logprobs_idx[i][-req.last_update_decode_tokens :] ) req.output_top_logprobs_val.append(output.next_token_top_logprobs_val[i]) From 776d35a55aa06bb561e97e288219fd758bf5913d Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sat, 25 Jan 2025 10:36:16 +0800 Subject: [PATCH 077/136] fix err --- python/sglang/srt/orchestration/std/scheduler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/orchestration/std/scheduler.py b/python/sglang/srt/orchestration/std/scheduler.py index 521000eaa3b..3a27a9a38a4 100644 --- a/python/sglang/srt/orchestration/std/scheduler.py +++ b/python/sglang/srt/orchestration/std/scheduler.py @@ -143,7 +143,7 @@ def _recv_requests(self) -> List[Req]: ) recv_reqs = work_reqs + control_reqs elif self.core.tp_size != 1: - recv_reqs = broadcast_pyobj(recv_reqs, self.tp_rank, self.core.tp_cpu_group) + recv_reqs = broadcast_pyobj(recv_reqs, self.core.tp_rank, self.core.tp_cpu_group) return recv_reqs def _process_input_requests(self, recv_reqs: List): From b0fa6c09d434ee00861f17221563f0a488d3a45f Mon Sep 17 00:00:00 2001 
From: fzyzcjy Date: Sat, 25 Jan 2025 10:37:35 +0800 Subject: [PATCH 078/136] mv _launch_subprocesses --- python/sglang/srt/entrypoints/engine.py | 176 +-------------- .../sglang/srt/orchestration/std/launcher.py | 201 ++++++++++++++++++ 2 files changed, 202 insertions(+), 175 deletions(-) create mode 100644 python/sglang/srt/orchestration/std/launcher.py diff --git a/python/sglang/srt/entrypoints/engine.py b/python/sglang/srt/entrypoints/engine.py index 5eaa8a6d441..255f416beed 100644 --- a/python/sglang/srt/entrypoints/engine.py +++ b/python/sglang/srt/entrypoints/engine.py @@ -21,23 +21,16 @@ import atexit import dataclasses import logging -import multiprocessing as mp import os -import signal import threading from typing import AsyncIterator, Dict, Iterator, List, Optional, Tuple, Union -from sglang.srt.orchestration.std.detokenizer import run_detokenizer_process - # Fix a bug of Python threading setattr(threading, "_register_atexit", lambda *args, **kwargs: None) import torch import uvloop -from sglang.srt.managers.data_parallel_controller import ( - run_data_parallel_controller_process, -) from sglang.srt.managers.io_struct import ( EmbeddingReqInput, GenerateReqInput, @@ -48,20 +41,10 @@ UpdateWeightsFromDistributedReqInput, UpdateWeightsFromTensorReqInput, ) -from sglang.srt.managers.scheduler import run_scheduler_process -from sglang.srt.openai_api.adapter import load_chat_template_for_openai_api -from sglang.srt.orchestration.std.orchestrator import StdOrchestrator -from sglang.srt.server_args import PortArgs, ServerArgs -from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter +from sglang.srt.server_args import ServerArgs from sglang.srt.utils import ( MultiprocessingSerializer, - assert_pkg_version, - configure_logger, kill_process_tree, - maybe_set_triton_cache_manager, - prepare_model_and_tokenizer, - set_prometheus_multiproc_dir, - set_ulimit, ) from sglang.version import __version__ @@ -287,160 +270,3 @@ def resume_memory_occupation(self): return loop.run_until_complete( self.orchestrator.resume_memory_occupation(obj, None) ) - - -def _set_envs_and_config(server_args: ServerArgs): - # Set global environments - os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" - os.environ["NCCL_CUMEM_ENABLE"] = "0" - os.environ["NCCL_NVLS_ENABLE"] = "0" - os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1" - os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "4" - - # Set prometheus env vars - if server_args.enable_metrics: - set_prometheus_multiproc_dir() - - # Set ulimit - set_ulimit() - - # Fix triton bugs - if server_args.tp_size * server_args.dp_size > 1: - # FIXME: remove this after https://github.com/triton-lang/triton/pull/4295 is used as a dependency. - maybe_set_triton_cache_manager() - - # Check flashinfer version - if server_args.attention_backend == "flashinfer": - assert_pkg_version( - "flashinfer", - "0.1.6", - "Please uninstall the old version and " - "reinstall the latest version by following the instructions " - "at https://docs.flashinfer.ai/installation.html.", - ) - - # Register the signal handler. - # The child processes will send SIGQUIT to this process when any error happens - # This process then clean up the whole process tree - def sigquit_handler(signum, frame): - logger.error( - "Received sigquit from a child proces. It usually means the child failed." 
- ) - kill_process_tree(os.getpid()) - - signal.signal(signal.SIGQUIT, sigquit_handler) - - # Set mp start method - mp.set_start_method("spawn", force=True) - - -def _launch_subprocesses(server_args: ServerArgs) -> Tuple[StdOrchestrator, Dict]: - """ - Launch the StdOrchestrator in the main process, the Scheduler in a subprocess, and the DetokenizerManager in another subprocess. - """ - # Configure global environment - configure_logger(server_args) - server_args.check_server_args() - _set_envs_and_config(server_args) - - # Allocate ports for inter-process communications - port_args = PortArgs.init_new(server_args) - logger.info(f"{server_args=}") - - # If using model from www.modelscope.cn, first download the model. - server_args.model_path, server_args.tokenizer_path = prepare_model_and_tokenizer( - server_args.model_path, server_args.tokenizer_path - ) - - scheduler_procs = [] - if server_args.dp_size == 1: - # Launch tensor parallel scheduler processes - memory_saver_adapter = TorchMemorySaverAdapter.create( - enable=server_args.enable_memory_saver - ) - - scheduler_pipe_readers = [] - tp_size_per_node = server_args.tp_size // server_args.nnodes - tp_rank_range = range( - tp_size_per_node * server_args.node_rank, - tp_size_per_node * (server_args.node_rank + 1), - ) - for tp_rank in tp_rank_range: - reader, writer = mp.Pipe(duplex=False) - gpu_id = server_args.base_gpu_id + tp_rank % tp_size_per_node - proc = mp.Process( - target=run_scheduler_process, - args=(server_args, port_args, gpu_id, tp_rank, None, writer), - ) - with memory_saver_adapter.configure_subprocess(): - proc.start() - scheduler_procs.append(proc) - scheduler_pipe_readers.append(reader) - else: - # Launch the data parallel controller - reader, writer = mp.Pipe(duplex=False) - scheduler_pipe_readers = [reader] - proc = mp.Process( - target=run_data_parallel_controller_process, - args=(server_args, port_args, writer), - ) - proc.start() - scheduler_procs.append(proc) - - if server_args.node_rank >= 1: - # In multi-node cases, non-zero rank nodes do not need to run tokenizer or detokenizer, - # so they can just wait here. - - for reader in scheduler_pipe_readers: - data = reader.recv() - assert data["status"] == "ready" - - if os.getenv("SGLANG_BLOCK_NONZERO_RANK_CHILDREN") == "0": - # When using `Engine` as a Python API, we don't want to block here. - return - - for proc in scheduler_procs: - proc.join() - logger.error( - f"Scheduler or DataParallelController {proc.pid} terminated with {proc.exitcode}" - ) - return - - # Launch detokenizer process - detoken_proc = mp.Process( - target=run_detokenizer_process, - args=( - server_args, - port_args, - ), - ) - detoken_proc.start() - - # Launch tokenizer process - orchestrator = StdOrchestrator(server_args, port_args) - if server_args.chat_template: - load_chat_template_for_openai_api(orchestrator, server_args.chat_template) - - # Wait for the model to finish loading - scheduler_infos = [] - for i in range(len(scheduler_pipe_readers)): - try: - data = scheduler_pipe_readers[i].recv() - except EOFError: - logger.error( - f"Rank {i} scheduler is dead. Please check if there are relevant logs." - ) - scheduler_procs[i].join() - logger.error(f"Exit code: {scheduler_procs[i].exitcode}") - raise - - if data["status"] != "ready": - raise RuntimeError( - "Initialization failed. Please see the error messages above." 
- ) - scheduler_infos.append(data) - - # Assume all schedulers have the same scheduler_info - scheduler_info = scheduler_infos[0] - orchestrator.configure_max_req_input_len(scheduler_info["max_req_input_len"]) - return orchestrator, scheduler_info diff --git a/python/sglang/srt/orchestration/std/launcher.py b/python/sglang/srt/orchestration/std/launcher.py new file mode 100644 index 00000000000..48a58fbe64e --- /dev/null +++ b/python/sglang/srt/orchestration/std/launcher.py @@ -0,0 +1,201 @@ +import asyncio +import atexit +import dataclasses +import logging +import multiprocessing as mp +import os +import signal +import threading +from typing import AsyncIterator, Dict, Iterator, List, Optional, Tuple, Union + +from sglang.srt.orchestration.std.detokenizer import run_detokenizer_process + +import torch +import uvloop + +from sglang.srt.managers.data_parallel_controller import ( + run_data_parallel_controller_process, +) +from sglang.srt.managers.io_struct import ( + EmbeddingReqInput, + GenerateReqInput, + GetWeightsByNameReqInput, + InitWeightsUpdateGroupReqInput, + ReleaseMemoryOccupationReqInput, + ResumeMemoryOccupationReqInput, + UpdateWeightsFromDistributedReqInput, + UpdateWeightsFromTensorReqInput, +) +from sglang.srt.openai_api.adapter import load_chat_template_for_openai_api +from sglang.srt.orchestration.std.orchestrator import StdOrchestrator +from sglang.srt.orchestration.std.scheduler import run_scheduler_process +from sglang.srt.server_args import PortArgs, ServerArgs +from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter +from sglang.srt.utils import ( + MultiprocessingSerializer, + assert_pkg_version, + configure_logger, + kill_process_tree, + maybe_set_triton_cache_manager, + prepare_model_and_tokenizer, + set_prometheus_multiproc_dir, + set_ulimit, +) +from sglang.version import __version__ + +def _launch_subprocesses(server_args: ServerArgs) -> Tuple[StdOrchestrator, Dict]: + """ + Launch the StdOrchestrator in the main process, the Scheduler in a subprocess, and the DetokenizerManager in another subprocess. + """ + # Configure global environment + configure_logger(server_args) + server_args.check_server_args() + _set_envs_and_config(server_args) + + # Allocate ports for inter-process communications + port_args = PortArgs.init_new(server_args) + logger.info(f"{server_args=}") + + # If using model from www.modelscope.cn, first download the model. 
+ server_args.model_path, server_args.tokenizer_path = prepare_model_and_tokenizer( + server_args.model_path, server_args.tokenizer_path + ) + + scheduler_procs = [] + if server_args.dp_size == 1: + # Launch tensor parallel scheduler processes + memory_saver_adapter = TorchMemorySaverAdapter.create( + enable=server_args.enable_memory_saver + ) + + scheduler_pipe_readers = [] + tp_size_per_node = server_args.tp_size // server_args.nnodes + tp_rank_range = range( + tp_size_per_node * server_args.node_rank, + tp_size_per_node * (server_args.node_rank + 1), + ) + for tp_rank in tp_rank_range: + reader, writer = mp.Pipe(duplex=False) + gpu_id = server_args.base_gpu_id + tp_rank % tp_size_per_node + proc = mp.Process( + target=run_scheduler_process, + args=(server_args, port_args, gpu_id, tp_rank, None, writer), + ) + with memory_saver_adapter.configure_subprocess(): + proc.start() + scheduler_procs.append(proc) + scheduler_pipe_readers.append(reader) + else: + # Launch the data parallel controller + reader, writer = mp.Pipe(duplex=False) + scheduler_pipe_readers = [reader] + proc = mp.Process( + target=run_data_parallel_controller_process, + args=(server_args, port_args, writer), + ) + proc.start() + scheduler_procs.append(proc) + + if server_args.node_rank >= 1: + # In multi-node cases, non-zero rank nodes do not need to run tokenizer or detokenizer, + # so they can just wait here. + + for reader in scheduler_pipe_readers: + data = reader.recv() + assert data["status"] == "ready" + + if os.getenv("SGLANG_BLOCK_NONZERO_RANK_CHILDREN") == "0": + # When using `Engine` as a Python API, we don't want to block here. + return + + for proc in scheduler_procs: + proc.join() + logger.error( + f"Scheduler or DataParallelController {proc.pid} terminated with {proc.exitcode}" + ) + return + + # Launch detokenizer process + detoken_proc = mp.Process( + target=run_detokenizer_process, + args=( + server_args, + port_args, + ), + ) + detoken_proc.start() + + # Launch tokenizer process + orchestrator = StdOrchestrator(server_args, port_args) + if server_args.chat_template: + load_chat_template_for_openai_api(orchestrator, server_args.chat_template) + + # Wait for the model to finish loading + scheduler_infos = [] + for i in range(len(scheduler_pipe_readers)): + try: + data = scheduler_pipe_readers[i].recv() + except EOFError: + logger.error( + f"Rank {i} scheduler is dead. Please check if there are relevant logs." + ) + scheduler_procs[i].join() + logger.error(f"Exit code: {scheduler_procs[i].exitcode}") + raise + + if data["status"] != "ready": + raise RuntimeError( + "Initialization failed. Please see the error messages above." + ) + scheduler_infos.append(data) + + # Assume all schedulers have the same scheduler_info + scheduler_info = scheduler_infos[0] + orchestrator.configure_max_req_input_len(scheduler_info["max_req_input_len"]) + return orchestrator, scheduler_info + +def _set_envs_and_config(server_args: ServerArgs): + # Set global environments + os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" + os.environ["NCCL_CUMEM_ENABLE"] = "0" + os.environ["NCCL_NVLS_ENABLE"] = "0" + os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1" + os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "4" + + # Set prometheus env vars + if server_args.enable_metrics: + set_prometheus_multiproc_dir() + + # Set ulimit + set_ulimit() + + # Fix triton bugs + if server_args.tp_size * server_args.dp_size > 1: + # FIXME: remove this after https://github.com/triton-lang/triton/pull/4295 is used as a dependency. 
+ maybe_set_triton_cache_manager() + + # Check flashinfer version + if server_args.attention_backend == "flashinfer": + assert_pkg_version( + "flashinfer", + "0.1.6", + "Please uninstall the old version and " + "reinstall the latest version by following the instructions " + "at https://docs.flashinfer.ai/installation.html.", + ) + + # Register the signal handler. + # The child processes will send SIGQUIT to this process when any error happens + # This process then cleans up the whole process tree + def sigquit_handler(signum, frame): + logger.error( + "Received sigquit from a child process. It usually means the child failed." + ) + kill_process_tree(os.getpid()) + + signal.signal(signal.SIGQUIT, sigquit_handler) + + # Set mp start method + mp.set_start_method("spawn", force=True) + + From 5f073a1d3ac607c5b6058e7a62b56c55e4046740 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sat, 25 Jan 2025 10:37:57 +0800 Subject: [PATCH 079/136] rename --- python/sglang/srt/entrypoints/engine.py | 4 +++- python/sglang/srt/entrypoints/http_server.py | 7 ++++--- python/sglang/srt/orchestration/std/launcher.py | 2 +- 3 files changed, 8 insertions(+), 5 deletions(-) diff --git a/python/sglang/srt/entrypoints/engine.py b/python/sglang/srt/entrypoints/engine.py index 255f416beed..572b5e165e1 100644 --- a/python/sglang/srt/entrypoints/engine.py +++ b/python/sglang/srt/entrypoints/engine.py @@ -25,6 +25,8 @@ import threading from typing import AsyncIterator, Dict, Iterator, List, Optional, Tuple, Union +from sglang.srt.orchestration.std.launcher import launch + # Fix a bug of Python threading setattr(threading, "_register_atexit", lambda *args, **kwargs: None) @@ -85,7 +87,7 @@ def __init__(self, **kwargs): atexit.register(self.shutdown) # Launch subprocesses - orchestrator, scheduler_info = _launch_subprocesses(server_args=server_args) + orchestrator, scheduler_info = launch(server_args=server_args) self.orchestrator = orchestrator self.scheduler_info = scheduler_info diff --git a/python/sglang/srt/entrypoints/http_server.py b/python/sglang/srt/entrypoints/http_server.py index d949877373a..19ab591787a 100644 --- a/python/sglang/srt/entrypoints/http_server.py +++ b/python/sglang/srt/entrypoints/http_server.py @@ -27,6 +27,8 @@ from http import HTTPStatus from typing import AsyncIterator, Dict, Optional +from sglang.srt.orchestration.std.launcher import launch + # Fix a bug of Python threading setattr(threading, "_register_atexit", lambda *args, **kwargs: None) @@ -38,7 +40,6 @@ from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import ORJSONResponse, Response, StreamingResponse -from sglang.srt.entrypoints.engine import _launch_subprocesses from sglang.srt.managers.io_struct import ( CloseSessionReqInput, ConfigureLoggingReq, @@ -227,7 +228,7 @@ async def flush_cache(): _global_state.orchestrator.flush_cache() return Response( content="Cache flushed.\nPlease check backend logs for more details. " - "(When there are running or waiting requests, the operation will not be performed.)\n", + "(When there are running or waiting requests, the operation will not be performed.)\n", status_code=200, ) @@ -462,7 +463,7 @@ def launch_server( 1. The HTTP server, Engine, and StdOrchestrator both run in the main process. 2. Inter-process communication is done through ICP (each process uses a different port) via the ZMQ library.
""" - orchestrator, scheduler_info = _launch_subprocesses(server_args=server_args) + orchestrator, scheduler_info = launch(server_args=server_args) set_global_state( _GlobalState( orchestrator=orchestrator, diff --git a/python/sglang/srt/orchestration/std/launcher.py b/python/sglang/srt/orchestration/std/launcher.py index 48a58fbe64e..73f66967c5c 100644 --- a/python/sglang/srt/orchestration/std/launcher.py +++ b/python/sglang/srt/orchestration/std/launcher.py @@ -43,7 +43,7 @@ ) from sglang.version import __version__ -def _launch_subprocesses(server_args: ServerArgs) -> Tuple[StdOrchestrator, Dict]: +def launch(server_args: ServerArgs) -> Tuple[StdOrchestrator, Dict]: """ Launch the StdOrchestrator in the main process, the Scheduler in a subprocess, and the DetokenizerManager in another subprocess. """ From ed2e250d1bd64b705fbaf67e3e93cbfbcda1083b Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sat, 25 Jan 2025 10:38:57 +0800 Subject: [PATCH 080/136] fix logger --- python/sglang/srt/orchestration/std/launcher.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/python/sglang/srt/orchestration/std/launcher.py b/python/sglang/srt/orchestration/std/launcher.py index 73f66967c5c..2a670ee135a 100644 --- a/python/sglang/srt/orchestration/std/launcher.py +++ b/python/sglang/srt/orchestration/std/launcher.py @@ -43,6 +43,9 @@ ) from sglang.version import __version__ +logger = logging.getLogger(__name__) +asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) + def launch(server_args: ServerArgs) -> Tuple[StdOrchestrator, Dict]: """ Launch the StdOrchestrator in the main process, the Scheduler in a subprocess, and the DetokenizerManager in another subprocess. From 8c8ca3fb98c741f85640c10bb072e14d45dde597 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sat, 25 Jan 2025 10:50:37 +0800 Subject: [PATCH 081/136] fmt --- python/sglang/srt/entrypoints/engine.py | 5 +-- python/sglang/srt/entrypoints/http_server.py | 2 +- .../srt/orchestration/std/detokenizer.py | 5 +-- .../sglang/srt/orchestration/std/launcher.py | 9 +++-- .../sglang/srt/orchestration/std/scheduler.py | 35 ++++++++++++++----- 5 files changed, 36 insertions(+), 20 deletions(-) diff --git a/python/sglang/srt/entrypoints/engine.py b/python/sglang/srt/entrypoints/engine.py index 572b5e165e1..80013913417 100644 --- a/python/sglang/srt/entrypoints/engine.py +++ b/python/sglang/srt/entrypoints/engine.py @@ -44,10 +44,7 @@ UpdateWeightsFromTensorReqInput, ) from sglang.srt.server_args import ServerArgs -from sglang.srt.utils import ( - MultiprocessingSerializer, - kill_process_tree, -) +from sglang.srt.utils import MultiprocessingSerializer, kill_process_tree from sglang.version import __version__ logger = logging.getLogger(__name__) diff --git a/python/sglang/srt/entrypoints/http_server.py b/python/sglang/srt/entrypoints/http_server.py index 19ab591787a..0ffec4ef334 100644 --- a/python/sglang/srt/entrypoints/http_server.py +++ b/python/sglang/srt/entrypoints/http_server.py @@ -228,7 +228,7 @@ async def flush_cache(): _global_state.orchestrator.flush_cache() return Response( content="Cache flushed.\nPlease check backend logs for more details. 
" - "(When there are running or waiting requests, the operation will not be performed.)\n", + "(When there are running or waiting requests, the operation will not be performed.)\n", status_code=200, ) diff --git a/python/sglang/srt/orchestration/std/detokenizer.py b/python/sglang/srt/orchestration/std/detokenizer.py index 12c2e8712e7..6673280e9ba 100644 --- a/python/sglang/srt/orchestration/std/detokenizer.py +++ b/python/sglang/srt/orchestration/std/detokenizer.py @@ -4,10 +4,11 @@ import psutil import setproctitle import zmq + from sglang.srt.managers.detokenizer_manager import DetokenizerManager -from sglang.srt.managers.io_struct import BatchTokenIDOut, BatchEmbeddingOut +from sglang.srt.managers.io_struct import BatchEmbeddingOut, BatchTokenIDOut from sglang.srt.server_args import PortArgs, ServerArgs -from sglang.srt.utils import get_zmq_socket, configure_logger +from sglang.srt.utils import configure_logger, get_zmq_socket from sglang.utils import TypeBasedDispatcher, get_exception_traceback logger = logging.getLogger(__name__) diff --git a/python/sglang/srt/orchestration/std/launcher.py b/python/sglang/srt/orchestration/std/launcher.py index 2a670ee135a..edc36b4fd4a 100644 --- a/python/sglang/srt/orchestration/std/launcher.py +++ b/python/sglang/srt/orchestration/std/launcher.py @@ -8,8 +8,6 @@ import threading from typing import AsyncIterator, Dict, Iterator, List, Optional, Tuple, Union -from sglang.srt.orchestration.std.detokenizer import run_detokenizer_process - import torch import uvloop @@ -27,6 +25,7 @@ UpdateWeightsFromTensorReqInput, ) from sglang.srt.openai_api.adapter import load_chat_template_for_openai_api +from sglang.srt.orchestration.std.detokenizer import run_detokenizer_process from sglang.srt.orchestration.std.orchestrator import StdOrchestrator from sglang.srt.orchestration.std.scheduler import run_scheduler_process from sglang.srt.server_args import PortArgs, ServerArgs @@ -46,6 +45,7 @@ logger = logging.getLogger(__name__) asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) + def launch(server_args: ServerArgs) -> Tuple[StdOrchestrator, Dict]: """ Launch the StdOrchestrator in the main process, the Scheduler in a subprocess, and the DetokenizerManager in another subprocess. 
@@ -76,7 +76,7 @@ def launch(server_args: ServerArgs) -> Tuple[StdOrchestrator, Dict]: tp_rank_range = range( tp_size_per_node * server_args.node_rank, tp_size_per_node * (server_args.node_rank + 1), - ) + ) for tp_rank in tp_rank_range: reader, writer = mp.Pipe(duplex=False) gpu_id = server_args.base_gpu_id + tp_rank % tp_size_per_node @@ -157,6 +157,7 @@ def launch(server_args: ServerArgs) -> Tuple[StdOrchestrator, Dict]: orchestrator.configure_max_req_input_len(scheduler_info["max_req_input_len"]) return orchestrator, scheduler_info + def _set_envs_and_config(server_args: ServerArgs): # Set global environments os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" @@ -200,5 +201,3 @@ def sigquit_handler(signum, frame): # Set mp start method mp.set_start_method("spawn", force=True) - - diff --git a/python/sglang/srt/orchestration/std/scheduler.py b/python/sglang/srt/orchestration/std/scheduler.py index 3a27a9a38a4..d4623d7cb14 100644 --- a/python/sglang/srt/orchestration/std/scheduler.py +++ b/python/sglang/srt/orchestration/std/scheduler.py @@ -3,24 +3,39 @@ import os import signal from types import SimpleNamespace -from typing import Optional, List +from typing import List, Optional import psutil import setproctitle import zmq -from sglang.srt.managers.io_struct import TokenizedGenerateReqInput, TokenizedEmbeddingReqInput, FlushCacheReq, \ - AbortReq, UpdateWeightFromDiskReqInput, InitWeightsUpdateGroupReqInput, UpdateWeightsFromDistributedReqInput, \ - UpdateWeightsFromTensorReqInput, GetWeightsByNameReqInput, ProfileReq, OpenSessionReqInput, CloseSessionReqInput, \ - ReleaseMemoryOccupationReqInput, ResumeMemoryOccupationReqInput + +from sglang.srt.managers.io_struct import ( + AbortReq, + CloseSessionReqInput, + FlushCacheReq, + GetWeightsByNameReqInput, + InitWeightsUpdateGroupReqInput, + OpenSessionReqInput, + ProfileReq, + ReleaseMemoryOccupationReqInput, + ResumeMemoryOccupationReqInput, + TokenizedEmbeddingReqInput, + TokenizedGenerateReqInput, + UpdateWeightFromDiskReqInput, + UpdateWeightsFromDistributedReqInput, + UpdateWeightsFromTensorReqInput, +) from sglang.srt.managers.schedule_batch import Req from sglang.srt.managers.scheduler import Scheduler from sglang.srt.server_args import PortArgs, ServerArgs from sglang.srt.utils import ( + broadcast_pyobj, configure_logger, get_bool_env_var, get_zmq_socket, set_gpu_proc_affinity, - suppress_other_loggers, broadcast_pyobj, ) + suppress_other_loggers, +) from sglang.utils import TypeBasedDispatcher, get_exception_traceback logger = logging.getLogger(__name__) @@ -143,7 +158,9 @@ def _recv_requests(self) -> List[Req]: ) recv_reqs = work_reqs + control_reqs elif self.core.tp_size != 1: - recv_reqs = broadcast_pyobj(recv_reqs, self.core.tp_rank, self.core.tp_cpu_group) + recv_reqs = broadcast_pyobj( + recv_reqs, self.core.tp_rank, self.core.tp_cpu_group + ) return recv_reqs def _process_input_requests(self, recv_reqs: List): @@ -186,7 +203,9 @@ def run_scheduler_process( # Create a scheduler and run the event loop try: - scheduler = Scheduler(server_args, port_args.nccl_port, gpu_id, tp_rank, dp_rank) + scheduler = Scheduler( + server_args, port_args.nccl_port, gpu_id, tp_rank, dp_rank + ) communicator = SchedulerCommunicator( core=scheduler, server_args=server_args, From 40b39b744e4eae760b80cf6ab8916b9fe78020b5 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sat, 25 Jan 2025 10:54:15 +0800 Subject: [PATCH 082/136] empty file --- python/sglang/srt/orchestration/spmd/__init__.py | 0 python/sglang/srt/orchestration/spmd/orchestrator.py | 0 2 files 
changed, 0 insertions(+), 0 deletions(-) create mode 100644 python/sglang/srt/orchestration/spmd/__init__.py create mode 100644 python/sglang/srt/orchestration/spmd/orchestrator.py diff --git a/python/sglang/srt/orchestration/spmd/__init__.py b/python/sglang/srt/orchestration/spmd/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/python/sglang/srt/orchestration/spmd/orchestrator.py b/python/sglang/srt/orchestration/spmd/orchestrator.py new file mode 100644 index 00000000000..e69de29bb2d From 7ea2563f342d4fafccdeca0451434855413022bc Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sat, 25 Jan 2025 10:54:57 +0800 Subject: [PATCH 083/136] cp spmd_orchestrator from old pr --- .../srt/orchestration/spmd/orchestrator.py | 56 +++++++++++++++++++ 1 file changed, 56 insertions(+) diff --git a/python/sglang/srt/orchestration/spmd/orchestrator.py b/python/sglang/srt/orchestration/spmd/orchestrator.py index e69de29bb2d..5ae15688726 100644 --- a/python/sglang/srt/orchestration/spmd/orchestrator.py +++ b/python/sglang/srt/orchestration/spmd/orchestrator.py @@ -0,0 +1,56 @@ +from typing import Any, Dict, List + +from sglang.srt.managers.detokenizer_manager import DetokenizerManager +from sglang.srt.managers.generation_manager import GenerationConverter +from sglang.srt.managers.io_struct import BatchTokenIDOut, GenerateReqInput +from sglang.srt.managers.scheduler import Scheduler +from sglang.srt.server_args import ServerArgs + + +class SpmdOrchestrator: + def __init__( + self, + server_args: ServerArgs, + nccl_port: int, + gpu_id: int, + tp_rank: int, + ): + self._scheduler = Scheduler( + server_args=server_args, + nccl_port=nccl_port, + gpu_id=gpu_id, + tp_rank=tp_rank, + dp_rank=None, + ) + self._generation_converter = GenerationConverter(server_args) + self._detokenizer = DetokenizerManager(server_args) + + def generate(self, obj: GenerateReqInput): + obj.normalize_batch_and_arguments() + tokenized_requests = self._generation_converter.tokenize_requests(obj) + rid_to_req_index = {r.rid: i for i, r in enumerate(tokenized_requests)} + + outputs: List[Dict[str, Any]] = [None] * obj.batch_size + + def _handle_scheduler_output(batch_token_id_out: BatchTokenIDOut): + batch_str_out = self._detokenizer.handle_batch_token_id_out( + batch_token_id_out + ) + for output_index in range(len(batch_str_out.rids)): + req_index = rid_to_req_index[batch_str_out.rids[output_index]] + outputs[req_index] = self._generation_converter.postprocess_response( + batch_str_out, index=output_index, req_obj=obj[req_index] + ) + + self._scheduler.on_generation_output = _handle_scheduler_output + + for tokenized_request in tokenized_requests: + self._scheduler.handle_generate_request(tokenized_request) + + while self._scheduler.process_batch(): + pass + + return outputs + + def shutdown(self): + self._scheduler.shutdown() From 15da0ca7acc4b763df3ba36e9f410b48dbdffdbe Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sat, 25 Jan 2025 10:56:17 +0800 Subject: [PATCH 084/136] cp shutdown from old pr --- python/sglang/srt/managers/scheduler.py | 60 ++++++++++--------- python/sglang/srt/managers/tp_worker.py | 3 + .../srt/managers/tp_worker_overlap_thread.py | 12 ++-- 3 files changed, 41 insertions(+), 34 deletions(-) diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index c14f58c448f..ead1370f8f2 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -27,7 +27,6 @@ import psutil import torch - from sglang.global_config import 
global_config from sglang.srt.configs.model_config import ModelConfig from sglang.srt.constrained.base_grammar_backend import create_grammar_backend @@ -323,8 +322,8 @@ def __init__( 1.0, ) self.new_token_ratio_decay = ( - self.init_new_token_ratio - self.min_new_token_ratio - ) / global_config.default_new_token_ratio_decay_steps + self.init_new_token_ratio - self.min_new_token_ratio + ) / global_config.default_new_token_ratio_decay_steps self.new_token_ratio = self.init_new_token_ratio # Tells whether the current running batch is full so that we can skip @@ -604,9 +603,9 @@ def handle_embedding_request( def log_prefill_stats(self, adder, can_run_list, running_bs, has_being_chunked): self.tree_cache_metrics["total"] += ( - adder.log_input_tokens + adder.log_hit_tokens - ) / 10**9 - self.tree_cache_metrics["hit"] += (adder.log_hit_tokens) / 10**9 + adder.log_input_tokens + adder.log_hit_tokens + ) / 10 ** 9 + self.tree_cache_metrics["hit"] += (adder.log_hit_tokens) / 10 ** 9 tree_cache_hit_rate = ( self.tree_cache_metrics["hit"] / self.tree_cache_metrics["total"] ) @@ -785,10 +784,10 @@ def get_new_batch_prefill(self) -> Optional[ScheduleBatch]: if ( self.lora_paths and len( - lora_set - | set([req.lora_path for req in adder.can_run_list]) - | set([req.lora_path]) - ) + lora_set + | set([req.lora_path for req in adder.can_run_list]) + | set([req.lora_path]) + ) > self.max_loras_per_batch ): self.batch_is_full = True @@ -1008,7 +1007,7 @@ def process_batch_result_prefill( if self.is_mixed_chunk and self.enable_overlap and req.finished(): # Free the one delayed token for the mixed decode batch j = len(batch.out_cache_loc) - len(batch.reqs) + i - self.token_to_kv_pool.free(batch.out_cache_loc[j : j + 1]) + self.token_to_kv_pool.free(batch.out_cache_loc[j: j + 1]) continue if req.is_being_chunked <= 0: @@ -1095,7 +1094,7 @@ def process_batch_result_decode( if self.enable_overlap and req.finished(): # Free the one delayed token - self.token_to_kv_pool.free(batch.out_cache_loc[i : i + 1]) + self.token_to_kv_pool.free(batch.out_cache_loc[i: i + 1]) continue if batch.spec_algorithm.is_none(): @@ -1155,15 +1154,15 @@ def add_logprob_return_values( if req.input_token_logprobs_val is None: input_token_logprobs_val = output.input_token_logprobs[ - pt : pt + num_input_logprobs - 1 - req.last_update_decode_tokens - ] + pt: pt + num_input_logprobs - 1 - req.last_update_decode_tokens + ] input_token_logprobs_idx = req.fill_ids[ - len(req.fill_ids) - - num_input_logprobs - + 1 : len(req.fill_ids) - - req.last_update_decode_tokens - ] + len(req.fill_ids) + - num_input_logprobs + + 1: len(req.fill_ids) + - req.last_update_decode_tokens + ] # Clip the padded hash values from image tokens. # Otherwise, it will lead to detokenization errors. 
input_token_logprobs_idx = [ @@ -1184,18 +1183,18 @@ def add_logprob_return_values( # Some decode tokens are re-computed in an extend batch req.output_token_logprobs_val.extend( output.input_token_logprobs[ - pt - + num_input_logprobs - - 1 - - req.last_update_decode_tokens : pt - + num_input_logprobs - - 1 + pt + + num_input_logprobs + - 1 + - req.last_update_decode_tokens: pt + + num_input_logprobs + - 1 ], ) req.output_token_logprobs_idx.extend( req.fill_ids[ - len(req.fill_ids) - - req.last_update_decode_tokens : len(req.fill_ids) + len(req.fill_ids) + - req.last_update_decode_tokens: len(req.fill_ids) ] ) @@ -1209,10 +1208,10 @@ def add_logprob_return_values( if req.last_update_decode_tokens != 0: req.output_top_logprobs_val.extend( - output.input_top_logprobs_val[i][-req.last_update_decode_tokens :] + output.input_top_logprobs_val[i][-req.last_update_decode_tokens:] ) req.output_top_logprobs_idx.extend( - output.input_top_logprobs_idx[i][-req.last_update_decode_tokens :] + output.input_top_logprobs_idx[i][-req.last_update_decode_tokens:] ) req.output_top_logprobs_val.append(output.next_token_top_logprobs_val[i]) @@ -1577,6 +1576,9 @@ def close_session(self, recv_req: CloseSessionReqInput): else: del self.sessions[session_id] + def shutdown(self): + self.tp_worker.shutdown() + def _export_static_state(model): return dict( diff --git a/python/sglang/srt/managers/tp_worker.py b/python/sglang/srt/managers/tp_worker.py index fd4dbae9900..931a690b4da 100644 --- a/python/sglang/srt/managers/tp_worker.py +++ b/python/sglang/srt/managers/tp_worker.py @@ -214,3 +214,6 @@ def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput): recv_req.name, recv_req.truncate_size ) return parameter + + def shutdown(self): + pass diff --git a/python/sglang/srt/managers/tp_worker_overlap_thread.py b/python/sglang/srt/managers/tp_worker_overlap_thread.py index 961b0bbdc11..13a5edf8967 100644 --- a/python/sglang/srt/managers/tp_worker_overlap_thread.py +++ b/python/sglang/srt/managers/tp_worker_overlap_thread.py @@ -22,7 +22,6 @@ import psutil import torch - from sglang.srt.managers.io_struct import ( GetWeightsByNameReqInput, InitWeightsUpdateGroupReqInput, @@ -144,7 +143,7 @@ def forward_thread_func_(self): # Update the future token ids map bs = len(model_worker_batch.seq_lens) self.future_token_ids_map[ - future_token_ids_ct + 1 : future_token_ids_ct + bs + 1 + future_token_ids_ct + 1: future_token_ids_ct + bs + 1 ] = next_token_ids # Copy results to the CPU @@ -204,8 +203,8 @@ def forward_batch_generation(self, model_worker_batch: ModelWorkerBatch): device=self.device, ) self.future_token_ids_ct = ( - self.future_token_ids_ct + bs - ) % self.future_token_ids_limit + self.future_token_ids_ct + bs + ) % self.future_token_ids_limit return None, future_next_token_ids def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput): @@ -230,5 +229,8 @@ def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput): return self.worker.get_weights_by_name(recv_req) def __delete__(self): + self.shutdown() + + def shutdown(self): self.input_queue.put((None, None)) - self.copy_queue.put((None, None, None)) + # self.copy_queue.put((None, None, None)) # the queue seems no longer exist From bff551cfe52cccad6409b7be3910b1b14b8f9cd2 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sat, 25 Jan 2025 10:57:16 +0800 Subject: [PATCH 085/136] cp engine_base --- python/sglang/srt/entrypoints/engine_base.py | 34 ++++++++++++++++++++ 1 file changed, 34 insertions(+) create mode 100644 
python/sglang/srt/entrypoints/engine_base.py diff --git a/python/sglang/srt/entrypoints/engine_base.py b/python/sglang/srt/entrypoints/engine_base.py new file mode 100644 index 00000000000..5e260b29328 --- /dev/null +++ b/python/sglang/srt/entrypoints/engine_base.py @@ -0,0 +1,34 @@ +from typing import Dict, List, Optional, Union + +from sglang.srt.managers.io_struct import GenerateReqInput + + +class EngineBase: + # Kept in the base class so that every engine exposes exactly the same API and shares the common logic + def generate( + self, + # The input prompt. It can be a single prompt or a batch of prompts. + prompt: Optional[Union[List[str], str]] = None, + sampling_params: Optional[Union[List[Dict], Dict]] = None, + # The token ids for text; one can either specify text or input_ids. + input_ids: Optional[Union[List[List[int]], List[int]]] = None, + return_logprob: Optional[Union[List[bool], bool]] = False, + logprob_start_len: Optional[Union[List[int], int]] = None, + top_logprobs_num: Optional[Union[List[int], int]] = None, + lora_path: Optional[List[Optional[str]]] = None, + stream: bool = False, + ): + obj = GenerateReqInput( + text=prompt, + input_ids=input_ids, + sampling_params=sampling_params, + return_logprob=return_logprob, + logprob_start_len=logprob_start_len, + top_logprobs_num=top_logprobs_num, + lora_path=lora_path, + stream=stream, + ) + return self._generate_impl(obj) + + def _generate_impl(self, obj: GenerateReqInput): + raise NotImplementedError From 970e3571312ce349e2da427bf6065543ef91ad34 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sat, 25 Jan 2025 10:57:52 +0800 Subject: [PATCH 086/136] extend base class --- python/sglang/srt/entrypoints/engine.py | 34 ++++--------------- 1 file changed, 5 insertions(+), 29 deletions(-) diff --git a/python/sglang/srt/entrypoints/engine.py b/python/sglang/srt/entrypoints/engine.py index 80013913417..2cac5d06618 100644 --- a/python/sglang/srt/entrypoints/engine.py +++ b/python/sglang/srt/entrypoints/engine.py @@ -23,8 +23,9 @@ import logging import os import threading -from typing import AsyncIterator, Dict, Iterator, List, Optional, Tuple, Union +from typing import AsyncIterator, Dict, List, Optional, Tuple, Union +from sglang.srt.entrypoints.engine_base import EngineBase from sglang.srt.orchestration.std.launcher import launch # Fix a bug of Python threading @@ -51,7 +52,7 @@ asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) -class Engine: +class Engine(EngineBase): """ The entry point to the inference engine. @@ -88,40 +89,15 @@ def __init__(self, **kwargs): self.orchestrator = orchestrator self.scheduler_info = scheduler_info - def generate( - self, - # The input prompt. It can be a single prompt or a batch of prompts. - prompt: Optional[Union[List[str], str]] = None, - sampling_params: Optional[Union[List[Dict], Dict]] = None, - # The token ids for text; one can either specify text or input_ids. - input_ids: Optional[Union[List[List[int]], List[int]]] = None, - return_logprob: Optional[Union[List[bool], bool]] = False, - logprob_start_len: Optional[Union[List[int], int]] = None, - top_logprobs_num: Optional[Union[List[int], int]] = None, - lora_path: Optional[List[Optional[str]]] = None, - custom_logit_processor: Optional[Union[List[str], str]] = None, - stream: bool = False, - ) -> Union[Dict, Iterator[Dict]]: + def _generate_impl(self, obj: GenerateReqInput): """ The arguments of this function is the same as `sglang/srt/managers/io_struct.py::GenerateReqInput`. Please refer to `GenerateReqInput` for the documentation.
""" - obj = GenerateReqInput( - text=prompt, - input_ids=input_ids, - sampling_params=sampling_params, - return_logprob=return_logprob, - logprob_start_len=logprob_start_len, - top_logprobs_num=top_logprobs_num, - lora_path=lora_path, - custom_logit_processor=custom_logit_processor, - stream=stream, - ) loop = asyncio.get_event_loop() generator = self.orchestrator.generate_request(obj, None) - if stream: - + if obj.stream: def generator_wrapper(): while True: try: From 418250544007ecd86aba46da9305ba2f9b29437d Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sat, 25 Jan 2025 10:58:36 +0800 Subject: [PATCH 087/136] cp engine_fragment from old pr --- .../sglang/srt/entrypoints/engine_fragment.py | 29 +++++++++++++++++++ 1 file changed, 29 insertions(+) create mode 100644 python/sglang/srt/entrypoints/engine_fragment.py diff --git a/python/sglang/srt/entrypoints/engine_fragment.py b/python/sglang/srt/entrypoints/engine_fragment.py new file mode 100644 index 00000000000..49a5cd4fcd2 --- /dev/null +++ b/python/sglang/srt/entrypoints/engine_fragment.py @@ -0,0 +1,29 @@ +from sglang.srt.entrypoints.engine_base import EngineBase +from sglang.srt.managers.io_struct import GenerateReqInput +from sglang.srt.orchestration.spmd.orchestrator import SpmdOrchestrator +from sglang.srt.server_args import ServerArgs + + +class EngineFragment(EngineBase): + def __init__( + self, + nccl_port: int, + gpu_id: int, + tp_rank: int, + log_level: str = "error", + *args, + **kwargs, + ): + server_args = ServerArgs(*args, log_level=log_level, **kwargs) + self._entrypoint = SpmdOrchestrator( + server_args=server_args, + nccl_port=nccl_port, + gpu_id=gpu_id, + tp_rank=tp_rank, + ) + + def _generate_impl(self, obj: GenerateReqInput): + return self._entrypoint.generate(obj) + + def shutdown(self): + self._entrypoint.shutdown() From 2ce877176e2935e320befc2d49fe7cc714662894 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sat, 25 Jan 2025 10:59:44 +0800 Subject: [PATCH 088/136] cp examples --- .../offline_batch_inference_torchrun.py | 80 +++++++++++++++++++ 1 file changed, 80 insertions(+) create mode 100644 examples/runtime/engine/offline_batch_inference_torchrun.py diff --git a/examples/runtime/engine/offline_batch_inference_torchrun.py b/examples/runtime/engine/offline_batch_inference_torchrun.py new file mode 100644 index 00000000000..5e1d5d8c34e --- /dev/null +++ b/examples/runtime/engine/offline_batch_inference_torchrun.py @@ -0,0 +1,80 @@ +import datetime +import os +import sys + +from sglang.srt.entrypoints.engine_fragment import EngineFragment + + +def run(): + """ + Example command: + ``` + torchrun --nproc_per_node=4 offline_batch_inference_torchrun.py + ``` + """ + + local_rank = int(os.environ["LOCAL_RANK"]) + rank = int(os.environ["RANK"]) + world_size = int(os.environ["WORLD_SIZE"]) + + def _log(text): + t = datetime.datetime.now().strftime("%H:%M:%S") + print(f"[{t}] [rank={rank}] {text}") + + _log( + f'start {local_rank=} {rank=} {world_size=} {sys.argv=} {os.environ.get("CUDA_VISIBLE_DEVICES")}' + ) + + tp_size = world_size + tp_rank = rank + _log(f"{tp_rank=} {tp_size=}") + + model_name, mem_fraction_static = "meta-llama/Llama-3.2-1B-Instruct", 0.1 + # model_name, mem_fraction_static = "meta-llama/Llama-3.1-70B-Instruct", 0.9 # test large models + + # TODO remove this in next PR + for k in [ + "GROUP_RANK", + "GROUP_WORLD_SIZE", + "LOCAL_RANK", + "LOCAL_WORLD_SIZE", + "MASTER_ADDR", + "MASTER_PORT", + "NCCL_DEBUG", + "OMP_NUM_THREADS", + "RANK", + "ROLE_NAME", + "ROLE_RANK", + "ROLE_WORLD_SIZE",
"TORCHELASTIC_ERROR_FILE", + "TORCHELASTIC_MAX_RESTARTS", + "TORCHELASTIC_RESTART_COUNT", + "TORCHELASTIC_RUN_ID", + "TORCHELASTIC_USE_AGENT_STORE", + "TORCH_NCCL_ASYNC_ERROR_HANDLING", + "WORLD_SIZE", + ]: + del os.environ[k] + + fragment = EngineFragment( + model_path=model_name, + mem_fraction_static=mem_fraction_static, + tp_size=tp_size, + tp_rank=tp_rank, + nccl_port=23456, + gpu_id=tp_rank, + ) + _log(f"{fragment=}") + + output = fragment.generate( + prompt=["1+1=2, 1+2=3, 1+3=4, 1+4=", "9-1=8, 8-1=7, 7-1="], + sampling_params=dict(max_new_tokens=16, temperature=0.0), + ) + _log(f"{output=}") + + fragment.shutdown() + _log(f"End script") + + +if __name__ == "__main__": + run() From 298604b41e4a1da75f0f64ace91a6fa364928622 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sat, 25 Jan 2025 11:02:52 +0800 Subject: [PATCH 089/136] update ci --- .github/workflows/pr-test.yml | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/.github/workflows/pr-test.yml b/.github/workflows/pr-test.yml index c5eeeee3c14..85a27bb5973 100644 --- a/.github/workflows/pr-test.yml +++ b/.github/workflows/pr-test.yml @@ -105,6 +105,12 @@ jobs: cd test/srt python3 test_update_weights_from_distributed.py + - name: Test EngineFragment + timeout-minutes: 10 + run: | + cd test/srt + python3 test_fragment.py + - name: Test expert parallelism (EP=2) timeout-minutes: 10 run: | From bda1243ff91d2a8990e70cf485ce8042c1c93a76 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sat, 25 Jan 2025 11:03:10 +0800 Subject: [PATCH 090/136] empty --- test/srt/test_fragment.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 test/srt/test_fragment.py diff --git a/test/srt/test_fragment.py b/test/srt/test_fragment.py new file mode 100644 index 00000000000..e69de29bb2d From d41fdc624bca6f42c7c720e72b321174342d4363 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sat, 25 Jan 2025 11:03:48 +0800 Subject: [PATCH 091/136] cp test from old pr --- test/srt/test_fragment.py | 94 +++++++++++++++++++++ 1 file changed, 94 insertions(+) diff --git a/test/srt/test_fragment.py b/test/srt/test_fragment.py index e69de29bb2d..a4cb056b3eb 100644 --- a/test/srt/test_fragment.py +++ b/test/srt/test_fragment.py @@ -0,0 +1,94 @@ +import multiprocessing +import multiprocessing as mp +import traceback +import unittest +from multiprocessing import Process + +from sglang.srt.entrypoints.engine_fragment import EngineFragment +from sglang.test.test_utils import DEFAULT_SMALL_MODEL_NAME_FOR_TEST + +_TP_SIZE = 2 + + +class TestFragment(unittest.TestCase): + def test_fragment(self): + multiprocessing.set_start_method("spawn") + nccl_port = 12345 + + processes = [] + output_reader, output_writer = mp.Pipe(duplex=False) + for tp_rank in range(_TP_SIZE): + p = Process( + target=_run_subprocess, + args=(tp_rank, nccl_port, output_writer), + ) + p.start() + processes.append(p) + + outputs = [output_reader.recv() for _ in range(_TP_SIZE)] + print(outputs) + for output in outputs: + self.assertEqual( + output, + [ + " to spend it outdoors. I decided to take a walk in the nearby park.", + " how to improve the performance of my website. I've been doing some research and", + " a new user of the platform. I am looking for a new laptop to buy", + " I'm looking for someone to help me with a project.\nI'm a student", + " the science of numbers and their properties.
It is a vast and complex field that", + ], + ) + + for p in processes: + p.join() + + +def _run_subprocess(tp_rank: int, nccl_port: int, output_writer): + try: + print(f"subprocess[{tp_rank=}] Start") + + fragment = EngineFragment( + model_path=DEFAULT_SMALL_MODEL_NAME_FOR_TEST, + mem_fraction_static=0.1, + tp_size=_TP_SIZE, + random_seed=42, + # fragment args + tp_rank=tp_rank, + gpu_id=tp_rank, + nccl_port=nccl_port, + ) + print(f"subprocess[{tp_rank=}] {fragment=}", flush=True) + + # NOTE: We deliberately call fragment.generate *twice* to confirm this function can be called multiple times + # In real batch generation, surely we should only call fragment.generate once + ans = [] + for prompt in [ + ["Today is a sunny day and I like", "I have a very good idea on"], + ["Hello, I am", "What is your name?", "Mathematics is defined as"], + ]: + print(f"subprocess[{tp_rank=}] Start generation", flush=True) + outputs = fragment.generate( + prompt=prompt, + sampling_params=[dict(max_new_tokens=16, temperature=0.0)] + * len(prompt), + ) + print( + f"subprocess[{tp_rank=}] End generation {prompt=} {outputs=}", + flush=True, + ) + ans += [o["text"] for o in outputs] + + output_writer.send(ans) + output_writer.close() + + except Exception as e: + print(f"subprocess[{tp_rank=}] has error: {e}", flush=True) + traceback.print_exc() + raise + + fragment.shutdown() + print(f"subprocess[{tp_rank=}] end", flush=True) + + +if __name__ == "__main__": + unittest.main() From 6863225cbaf1057e38d5fda983392433fd11c257 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sat, 25 Jan 2025 11:04:58 +0800 Subject: [PATCH 092/136] fmt --- python/sglang/srt/entrypoints/engine.py | 1 + python/sglang/srt/entrypoints/engine_base.py | 22 +++---- python/sglang/srt/managers/scheduler.py | 57 ++++++++++--------- .../srt/managers/tp_worker_overlap_thread.py | 7 ++- test/srt/test_fragment.py | 2 +- 5 files changed, 46 insertions(+), 43 deletions(-) diff --git a/python/sglang/srt/entrypoints/engine.py b/python/sglang/srt/entrypoints/engine.py index 2cac5d06618..6e5d55bcb34 100644 --- a/python/sglang/srt/entrypoints/engine.py +++ b/python/sglang/srt/entrypoints/engine.py @@ -98,6 +98,7 @@ def _generate_impl(self, obj: GenerateReqInput): generator = self.orchestrator.generate_request(obj, None) if obj.stream: + def generator_wrapper(): while True: try: diff --git a/python/sglang/srt/entrypoints/engine_base.py b/python/sglang/srt/entrypoints/engine_base.py index 5e260b29328..d7ba77b183a 100644 --- a/python/sglang/srt/entrypoints/engine_base.py +++ b/python/sglang/srt/entrypoints/engine_base.py @@ -6,17 +6,17 @@ class EngineBase: # Make it in base class to ensure API is exactly the same, as well as extracting common logic def generate( - self, - # The input prompt. It can be a single prompt or a batch of prompts. - prompt: Optional[Union[List[str], str]] = None, - sampling_params: Optional[Union[List[Dict], Dict]] = None, - # The token ids for text; one can either specify text or input_ids. - input_ids: Optional[Union[List[List[int]], List[int]]] = None, - return_logprob: Optional[Union[List[bool], bool]] = False, - logprob_start_len: Optional[Union[List[int], int]] = None, - top_logprobs_num: Optional[Union[List[int], int]] = None, - lora_path: Optional[List[Optional[str]]] = None, - stream: bool = False, + self, + # The input prompt. It can be a single prompt or a batch of prompts. 
+ prompt: Optional[Union[List[str], str]] = None, + sampling_params: Optional[Union[List[Dict], Dict]] = None, + # The token ids for text; one can either specify text or input_ids. + input_ids: Optional[Union[List[List[int]], List[int]]] = None, + return_logprob: Optional[Union[List[bool], bool]] = False, + logprob_start_len: Optional[Union[List[int], int]] = None, + top_logprobs_num: Optional[Union[List[int], int]] = None, + lora_path: Optional[List[Optional[str]]] = None, + stream: bool = False, ): obj = GenerateReqInput( text=prompt, diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index ead1370f8f2..db565305c89 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -27,6 +27,7 @@ import psutil import torch + from sglang.global_config import global_config from sglang.srt.configs.model_config import ModelConfig from sglang.srt.constrained.base_grammar_backend import create_grammar_backend @@ -322,8 +323,8 @@ def __init__( 1.0, ) self.new_token_ratio_decay = ( - self.init_new_token_ratio - self.min_new_token_ratio - ) / global_config.default_new_token_ratio_decay_steps + self.init_new_token_ratio - self.min_new_token_ratio + ) / global_config.default_new_token_ratio_decay_steps self.new_token_ratio = self.init_new_token_ratio # Tells whether the current running batch is full so that we can skip @@ -603,9 +604,9 @@ def handle_embedding_request( def log_prefill_stats(self, adder, can_run_list, running_bs, has_being_chunked): self.tree_cache_metrics["total"] += ( - adder.log_input_tokens + adder.log_hit_tokens - ) / 10 ** 9 - self.tree_cache_metrics["hit"] += (adder.log_hit_tokens) / 10 ** 9 + adder.log_input_tokens + adder.log_hit_tokens + ) / 10**9 + self.tree_cache_metrics["hit"] += (adder.log_hit_tokens) / 10**9 tree_cache_hit_rate = ( self.tree_cache_metrics["hit"] / self.tree_cache_metrics["total"] ) @@ -784,10 +785,10 @@ def get_new_batch_prefill(self) -> Optional[ScheduleBatch]: if ( self.lora_paths and len( - lora_set - | set([req.lora_path for req in adder.can_run_list]) - | set([req.lora_path]) - ) + lora_set + | set([req.lora_path for req in adder.can_run_list]) + | set([req.lora_path]) + ) > self.max_loras_per_batch ): self.batch_is_full = True @@ -1007,7 +1008,7 @@ def process_batch_result_prefill( if self.is_mixed_chunk and self.enable_overlap and req.finished(): # Free the one delayed token for the mixed decode batch j = len(batch.out_cache_loc) - len(batch.reqs) + i - self.token_to_kv_pool.free(batch.out_cache_loc[j: j + 1]) + self.token_to_kv_pool.free(batch.out_cache_loc[j : j + 1]) continue if req.is_being_chunked <= 0: @@ -1094,7 +1095,7 @@ def process_batch_result_decode( if self.enable_overlap and req.finished(): # Free the one delayed token - self.token_to_kv_pool.free(batch.out_cache_loc[i: i + 1]) + self.token_to_kv_pool.free(batch.out_cache_loc[i : i + 1]) continue if batch.spec_algorithm.is_none(): @@ -1154,15 +1155,15 @@ def add_logprob_return_values( if req.input_token_logprobs_val is None: input_token_logprobs_val = output.input_token_logprobs[ - pt: pt + num_input_logprobs - 1 - req.last_update_decode_tokens - ] + pt : pt + num_input_logprobs - 1 - req.last_update_decode_tokens + ] input_token_logprobs_idx = req.fill_ids[ - len(req.fill_ids) - - num_input_logprobs - + 1: len(req.fill_ids) - - req.last_update_decode_tokens - ] + len(req.fill_ids) + - num_input_logprobs + + 1 : len(req.fill_ids) + - req.last_update_decode_tokens + ] # Clip the padded hash values from 
image tokens. # Otherwise, it will lead to detokenization errors. input_token_logprobs_idx = [ @@ -1183,18 +1184,18 @@ def add_logprob_return_values( # Some decode tokens are re-computed in an extend batch req.output_token_logprobs_val.extend( output.input_token_logprobs[ - pt - + num_input_logprobs - - 1 - - req.last_update_decode_tokens: pt - + num_input_logprobs - - 1 + pt + + num_input_logprobs + - 1 + - req.last_update_decode_tokens : pt + + num_input_logprobs + - 1 ], ) req.output_token_logprobs_idx.extend( req.fill_ids[ - len(req.fill_ids) - - req.last_update_decode_tokens: len(req.fill_ids) + len(req.fill_ids) + - req.last_update_decode_tokens : len(req.fill_ids) ] ) @@ -1208,10 +1209,10 @@ def add_logprob_return_values( if req.last_update_decode_tokens != 0: req.output_top_logprobs_val.extend( - output.input_top_logprobs_val[i][-req.last_update_decode_tokens:] + output.input_top_logprobs_val[i][-req.last_update_decode_tokens :] ) req.output_top_logprobs_idx.extend( - output.input_top_logprobs_idx[i][-req.last_update_decode_tokens:] + output.input_top_logprobs_idx[i][-req.last_update_decode_tokens :] ) req.output_top_logprobs_val.append(output.next_token_top_logprobs_val[i]) diff --git a/python/sglang/srt/managers/tp_worker_overlap_thread.py b/python/sglang/srt/managers/tp_worker_overlap_thread.py index 13a5edf8967..a7068eae659 100644 --- a/python/sglang/srt/managers/tp_worker_overlap_thread.py +++ b/python/sglang/srt/managers/tp_worker_overlap_thread.py @@ -22,6 +22,7 @@ import psutil import torch + from sglang.srt.managers.io_struct import ( GetWeightsByNameReqInput, InitWeightsUpdateGroupReqInput, @@ -143,7 +144,7 @@ def forward_thread_func_(self): # Update the future token ids map bs = len(model_worker_batch.seq_lens) self.future_token_ids_map[ - future_token_ids_ct + 1: future_token_ids_ct + bs + 1 + future_token_ids_ct + 1 : future_token_ids_ct + bs + 1 ] = next_token_ids # Copy results to the CPU @@ -203,8 +204,8 @@ def forward_batch_generation(self, model_worker_batch: ModelWorkerBatch): device=self.device, ) self.future_token_ids_ct = ( - self.future_token_ids_ct + bs - ) % self.future_token_ids_limit + self.future_token_ids_ct + bs + ) % self.future_token_ids_limit return None, future_next_token_ids def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput): diff --git a/test/srt/test_fragment.py b/test/srt/test_fragment.py index a4cb056b3eb..f62cf0cb4a6 100644 --- a/test/srt/test_fragment.py +++ b/test/srt/test_fragment.py @@ -70,7 +70,7 @@ def _run_subprocess(tp_rank: int, nccl_port: int, output_writer): outputs = fragment.generate( prompt=prompt, sampling_params=[dict(max_new_tokens=16, temperature=0.0)] - * len(prompt), + * len(prompt), ) print( f"subprocess[{tp_rank=}] End generation {prompt=} {outputs=}", From 77acd55451696f69fef7e6d6cac029c17c22293e Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sat, 25 Jan 2025 11:14:12 +0800 Subject: [PATCH 093/136] fix import --- .../srt/managers/data_parallel_controller.py | 29 +++++++++---------- 1 file changed, 14 insertions(+), 15 deletions(-) diff --git a/python/sglang/srt/managers/data_parallel_controller.py b/python/sglang/srt/managers/data_parallel_controller.py index 3b959b1ba76..52ba48ca53f 100644 --- a/python/sglang/srt/managers/data_parallel_controller.py +++ b/python/sglang/srt/managers/data_parallel_controller.py @@ -22,13 +22,12 @@ import psutil import setproctitle import zmq - from sglang.srt.layers.dp_attention import compute_dp_attention_world_info from sglang.srt.managers.io_struct import ( 
TokenizedEmbeddingReqInput, TokenizedGenerateReqInput, ) -from sglang.srt.managers.scheduler import run_scheduler_process +from sglang.srt.orchestration.std.scheduler import run_scheduler_process from sglang.srt.server_args import PortArgs, ServerArgs from sglang.srt.utils import bind_port, configure_logger, get_zmq_socket from sglang.utils import get_exception_traceback @@ -143,11 +142,11 @@ def launch_dp_attention_schedulers(self, server_args, port_args): return dp_port_args def launch_tensor_parallel_group( - self, - server_args: ServerArgs, - port_args: PortArgs, - base_gpu_id: int, - dp_rank: int, + self, + server_args: ServerArgs, + port_args: PortArgs, + base_gpu_id: int, + dp_rank: int, ): if not server_args.enable_dp_attention: logger.info(f"Launch DP{dp_rank} starting at GPU #{base_gpu_id}.") @@ -210,11 +209,11 @@ def event_loop(self): break if isinstance( - recv_req, - ( - TokenizedGenerateReqInput, - TokenizedEmbeddingReqInput, - ), + recv_req, + ( + TokenizedGenerateReqInput, + TokenizedEmbeddingReqInput, + ), ): self.dispatching(recv_req) else: @@ -224,9 +223,9 @@ def event_loop(self): def run_data_parallel_controller_process( - server_args: ServerArgs, - port_args: PortArgs, - pipe_writer, + server_args: ServerArgs, + port_args: PortArgs, + pipe_writer, ): setproctitle.setproctitle("sglang::data_parallel_controller") configure_logger(server_args) From 825db47a8dcbb3f248060012de58cad3fc31afe5 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sat, 25 Jan 2025 11:15:28 +0800 Subject: [PATCH 094/136] fix import --- examples/runtime/engine/offline_batch_inference_torchrun.py | 2 +- test/srt/test_fragment.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/runtime/engine/offline_batch_inference_torchrun.py b/examples/runtime/engine/offline_batch_inference_torchrun.py index 5e1d5d8c34e..9d37ab8a899 100644 --- a/examples/runtime/engine/offline_batch_inference_torchrun.py +++ b/examples/runtime/engine/offline_batch_inference_torchrun.py @@ -2,7 +2,7 @@ import os import sys -from sglang.srt.server.engine_fragment import EngineFragment +from sglang.srt.entrypoints.engine_fragment import EngineFragment def run(): diff --git a/test/srt/test_fragment.py b/test/srt/test_fragment.py index f62cf0cb4a6..e54642e09c5 100644 --- a/test/srt/test_fragment.py +++ b/test/srt/test_fragment.py @@ -4,7 +4,7 @@ import unittest from multiprocessing import Process -from sglang.srt.server.engine_fragment import EngineFragment +from sglang.srt.entrypoints.engine_fragment import EngineFragment from sglang.test.test_utils import DEFAULT_SMALL_MODEL_NAME_FOR_TEST _TP_SIZE = 2 @@ -70,7 +70,7 @@ def _run_subprocess(tp_rank: int, nccl_port: int, output_writer): outputs = fragment.generate( prompt=prompt, sampling_params=[dict(max_new_tokens=16, temperature=0.0)] - * len(prompt), + * len(prompt), ) print( f"subprocess[{tp_rank=}] End generation {prompt=} {outputs=}", From d0c3de99d44a9a9868747db81d2ca1f10a371ad7 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sat, 25 Jan 2025 11:18:47 +0800 Subject: [PATCH 095/136] fix err --- python/sglang/srt/managers/scheduler.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index db565305c89..ef3e67cffde 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -409,6 +409,7 @@ def _process_batch_normal(self): self.new_token_ratio = self.init_new_token_ratio self.last_batch = batch + return batch is not None @torch.no_grad() 
def _process_batch_overlap(self): @@ -442,6 +443,7 @@ def _process_batch_overlap(self): self.new_token_ratio = self.init_new_token_ratio self.last_batch = batch + return batch is not None def handle_generate_request( self, From 5ab9cda35fa9884004884bd5984ee89b9b49d8c2 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sat, 25 Jan 2025 11:23:20 +0800 Subject: [PATCH 096/136] fix minor --- examples/runtime/engine/offline_batch_inference_torchrun.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/runtime/engine/offline_batch_inference_torchrun.py b/examples/runtime/engine/offline_batch_inference_torchrun.py index 9d37ab8a899..939f747f504 100644 --- a/examples/runtime/engine/offline_batch_inference_torchrun.py +++ b/examples/runtime/engine/offline_batch_inference_torchrun.py @@ -40,7 +40,6 @@ def _log(text): "LOCAL_WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", - "NCCL_DEBUG", "OMP_NUM_THREADS", "RANK", "ROLE_NAME", @@ -54,7 +53,8 @@ def _log(text): "TORCH_NCCL_ASYNC_ERROR_HANDLING", "WORLD_SIZE", ]: - del os.environ[k] + if k in os.environ: + del os.environ[k] fragment = EngineFragment( model_path=model_name, From c96b014425e123dc73c171dc3fcc4fb468c8dd0a Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sat, 25 Jan 2025 11:25:05 +0800 Subject: [PATCH 097/136] fmt --- .../srt/managers/data_parallel_controller.py | 27 ++++++++++--------- test/srt/test_fragment.py | 2 +- 2 files changed, 15 insertions(+), 14 deletions(-) diff --git a/python/sglang/srt/managers/data_parallel_controller.py b/python/sglang/srt/managers/data_parallel_controller.py index 52ba48ca53f..cadaffa673c 100644 --- a/python/sglang/srt/managers/data_parallel_controller.py +++ b/python/sglang/srt/managers/data_parallel_controller.py @@ -22,6 +22,7 @@ import psutil import setproctitle import zmq + from sglang.srt.layers.dp_attention import compute_dp_attention_world_info from sglang.srt.managers.io_struct import ( TokenizedEmbeddingReqInput, @@ -142,11 +143,11 @@ def launch_dp_attention_schedulers(self, server_args, port_args): return dp_port_args def launch_tensor_parallel_group( - self, - server_args: ServerArgs, - port_args: PortArgs, - base_gpu_id: int, - dp_rank: int, + self, + server_args: ServerArgs, + port_args: PortArgs, + base_gpu_id: int, + dp_rank: int, ): if not server_args.enable_dp_attention: logger.info(f"Launch DP{dp_rank} starting at GPU #{base_gpu_id}.") @@ -209,11 +210,11 @@ def event_loop(self): break if isinstance( - recv_req, - ( - TokenizedGenerateReqInput, - TokenizedEmbeddingReqInput, - ), + recv_req, + ( + TokenizedGenerateReqInput, + TokenizedEmbeddingReqInput, + ), ): self.dispatching(recv_req) else: @@ -223,9 +224,9 @@ def event_loop(self): def run_data_parallel_controller_process( - server_args: ServerArgs, - port_args: PortArgs, - pipe_writer, + server_args: ServerArgs, + port_args: PortArgs, + pipe_writer, ): setproctitle.setproctitle("sglang::data_parallel_controller") configure_logger(server_args) diff --git a/test/srt/test_fragment.py b/test/srt/test_fragment.py index e54642e09c5..19fb9a185dc 100644 --- a/test/srt/test_fragment.py +++ b/test/srt/test_fragment.py @@ -70,7 +70,7 @@ def _run_subprocess(tp_rank: int, nccl_port: int, output_writer): outputs = fragment.generate( prompt=prompt, sampling_params=[dict(max_new_tokens=16, temperature=0.0)] - * len(prompt), + * len(prompt), ) print( f"subprocess[{tp_rank=}] End generation {prompt=} {outputs=}", From c80e610e213ab8f24f30ddbd02f63ef2a7a8bfbf Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sat, 25 Jan 2025 11:35:45 +0800 
Subject: [PATCH 098/136] fix import --- python/sglang/srt/managers/data_parallel_controller.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/managers/data_parallel_controller.py b/python/sglang/srt/managers/data_parallel_controller.py index 3b959b1ba76..cadaffa673c 100644 --- a/python/sglang/srt/managers/data_parallel_controller.py +++ b/python/sglang/srt/managers/data_parallel_controller.py @@ -28,7 +28,7 @@ TokenizedEmbeddingReqInput, TokenizedGenerateReqInput, ) -from sglang.srt.managers.scheduler import run_scheduler_process +from sglang.srt.orchestration.std.scheduler import run_scheduler_process from sglang.srt.server_args import PortArgs, ServerArgs from sglang.srt.utils import bind_port, configure_logger, get_zmq_socket from sglang.utils import get_exception_traceback From a1b434e0434cbb512f3eee3c34146755335249a0 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sat, 25 Jan 2025 11:50:34 +0800 Subject: [PATCH 099/136] fix import --- python/sglang/bench_one_batch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/bench_one_batch.py b/python/sglang/bench_one_batch.py index bc7a9c7a1a7..43fe538a616 100644 --- a/python/sglang/bench_one_batch.py +++ b/python/sglang/bench_one_batch.py @@ -57,11 +57,11 @@ import torch.distributed as dist from sglang.srt.configs.model_config import ModelConfig -from sglang.srt.entrypoints.engine import _set_envs_and_config from sglang.srt.hf_transformers_utils import get_tokenizer from sglang.srt.managers.schedule_batch import Req, ScheduleBatch from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_executor.model_runner import ModelRunner +from sglang.srt.orchestration.std.launcher import _set_envs_and_config from sglang.srt.sampling.sampling_params import SamplingParams from sglang.srt.server_args import PortArgs, ServerArgs from sglang.srt.speculative.spec_info import SpeculativeAlgorithm From f889a6c956fc34acf35c8535717a3c83eca01615 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sat, 25 Jan 2025 12:54:16 +0800 Subject: [PATCH 100/136] fix import --- python/sglang/bench_one_batch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/bench_one_batch.py b/python/sglang/bench_one_batch.py index bc7a9c7a1a7..43fe538a616 100644 --- a/python/sglang/bench_one_batch.py +++ b/python/sglang/bench_one_batch.py @@ -57,11 +57,11 @@ import torch.distributed as dist from sglang.srt.configs.model_config import ModelConfig -from sglang.srt.entrypoints.engine import _set_envs_and_config from sglang.srt.hf_transformers_utils import get_tokenizer from sglang.srt.managers.schedule_batch import Req, ScheduleBatch from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_executor.model_runner import ModelRunner +from sglang.srt.orchestration.std.launcher import _set_envs_and_config from sglang.srt.sampling.sampling_params import SamplingParams from sglang.srt.server_args import PortArgs, ServerArgs from sglang.srt.speculative.spec_info import SpeculativeAlgorithm From 4543136eeae0ef6aeb18588ba07ebc8e149e7b8d Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sun, 26 Jan 2025 12:06:42 +0800 Subject: [PATCH 101/136] bump ci --- python/sglang/srt/managers/generation_manager.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/managers/generation_manager.py b/python/sglang/srt/managers/generation_manager.py index f87d5d75fe9..5b1602c6c3c 100644 --- 
a/python/sglang/srt/managers/generation_manager.py +++ b/python/sglang/srt/managers/generation_manager.py @@ -652,7 +652,7 @@ def background_task(): @dataclasses.dataclass class _ReqState: - """Store the state a request.""" + """Store the state of a request.""" out_list: List finished: bool From aeed015c8823724578302730fa9c072a432713ef Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sun, 26 Jan 2025 14:52:00 +0800 Subject: [PATCH 102/136] Revert "bump ci" This reverts commit 4543136eeae0ef6aeb18588ba07ebc8e149e7b8d. --- python/sglang/srt/managers/generation_manager.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/managers/generation_manager.py b/python/sglang/srt/managers/generation_manager.py index 5b1602c6c3c..f87d5d75fe9 100644 --- a/python/sglang/srt/managers/generation_manager.py +++ b/python/sglang/srt/managers/generation_manager.py @@ -652,7 +652,7 @@ def background_task(): @dataclasses.dataclass class _ReqState: - """Store the state of a request.""" + """Store the state a request.""" out_list: List finished: bool From f62574e67087caf8b6169d640cad3a07e7625d95 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Mon, 27 Jan 2025 12:26:14 +0800 Subject: [PATCH 103/136] fix import --- python/sglang/srt/entrypoints/http_server.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/entrypoints/http_server.py b/python/sglang/srt/entrypoints/http_server.py index da2beb6b6f9..02c98425d11 100644 --- a/python/sglang/srt/entrypoints/http_server.py +++ b/python/sglang/srt/entrypoints/http_server.py @@ -40,7 +40,6 @@ from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import ORJSONResponse, Response, StreamingResponse -from sglang.srt.entrypoints.engine import _launch_subprocesses from sglang.srt.function_call_parser import FunctionCallParser from sglang.srt.managers.io_struct import ( CloseSessionReqInput, @@ -70,6 +69,7 @@ v1_retrieve_file_content, ) from sglang.srt.openai_api.protocol import ModelCard, ModelList +from sglang.srt.orchestration.std.launcher import launch from sglang.srt.orchestration.std.orchestrator import StdOrchestrator from sglang.srt.server_args import ServerArgs from sglang.srt.utils import ( From 6f057aab75b30c981d1a581844de21f51b527c6f Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Mon, 27 Jan 2025 12:27:52 +0800 Subject: [PATCH 104/136] fix rename --- python/sglang/srt/openai_api/adapter.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/openai_api/adapter.py b/python/sglang/srt/openai_api/adapter.py index 6e01794c163..08a0f0d7532 100644 --- a/python/sglang/srt/openai_api/adapter.py +++ b/python/sglang/srt/openai_api/adapter.py @@ -915,7 +915,7 @@ def v1_chat_generate_request( # has a different tools input format that is not compatiable # with openAI's apply_chat_template tool_call format, like Mistral. 
tools = [t if "function" in t else {"function": t} for t in tools] - prompt_ids = tokenizer_manager.tokenizer.apply_chat_template( + prompt_ids = orchestrator.tokenizer.apply_chat_template( openai_compatible_messages, tokenize=True, add_generation_prompt=True, @@ -1282,7 +1282,7 @@ async def generate_stream_resp(): if index not in parser_dict: parser_dict[index] = FunctionCallParser( tools=request.tools, - tool_call_parser=tokenizer_manager.server_args.tool_call_parser, + tool_call_parser=orchestrator.server_args.tool_call_parser, ) parser = parser_dict[index] From bbd790887695656fcd45807b463f055d85ebad03 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sun, 26 Jan 2025 12:06:42 +0800 Subject: [PATCH 105/136] bump ci --- python/sglang/srt/managers/generation_manager.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/managers/generation_manager.py b/python/sglang/srt/managers/generation_manager.py index 18686424478..d781daf09ea 100644 --- a/python/sglang/srt/managers/generation_manager.py +++ b/python/sglang/srt/managers/generation_manager.py @@ -655,7 +655,7 @@ def background_task(): @dataclasses.dataclass class _ReqState: - """Store the state a request.""" + """Store the state of a request.""" out_list: List finished: bool From ce8fbef330f3b281f93fd262603514b68787d4ef Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Mon, 27 Jan 2025 12:51:28 +0800 Subject: [PATCH 106/136] fix rename --- docs/backend/function_calling.ipynb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/backend/function_calling.ipynb b/docs/backend/function_calling.ipynb index 3de80aadf11..efab4ae84c2 100644 --- a/docs/backend/function_calling.ipynb +++ b/docs/backend/function_calling.ipynb @@ -424,7 +424,7 @@ "from sglang.srt.managers.io_struct import Tool, Function\n", "\n", "llm = sgl.Engine(model_path=\"meta-llama/Meta-Llama-3.1-8B-Instruct\")\n", - "tokenizer = llm.tokenizer_manager.tokenizer\n", + "tokenizer = llm.orchestrator.tokenizer\n", "input_ids = tokenizer.apply_chat_template(\n", " messages, tokenize=True, add_generation_prompt=True, tools=tools\n", ")\n", From e669e451e872953588ff132d9a8cf7c4210b773a Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Mon, 27 Jan 2025 13:16:00 +0800 Subject: [PATCH 107/136] Revert "bump ci" This reverts commit bbd790887695656fcd45807b463f055d85ebad03. 
--- python/sglang/srt/managers/generation_manager.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/managers/generation_manager.py b/python/sglang/srt/managers/generation_manager.py index d781daf09ea..18686424478 100644 --- a/python/sglang/srt/managers/generation_manager.py +++ b/python/sglang/srt/managers/generation_manager.py @@ -655,7 +655,7 @@ def background_task(): @dataclasses.dataclass class _ReqState: - """Store the state of a request.""" + """Store the state a request.""" out_list: List finished: bool From 5ed169cb87c0dbf7d50ec2ca2071fc9d0ee274ff Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sun, 26 Jan 2025 12:06:42 +0800 Subject: [PATCH 108/136] bump ci --- python/sglang/srt/managers/generation_manager.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/managers/generation_manager.py b/python/sglang/srt/managers/generation_manager.py index 18686424478..d781daf09ea 100644 --- a/python/sglang/srt/managers/generation_manager.py +++ b/python/sglang/srt/managers/generation_manager.py @@ -655,7 +655,7 @@ def background_task(): @dataclasses.dataclass class _ReqState: - """Store the state a request.""" + """Store the state of a request.""" out_list: List finished: bool From 3d46e2f337e5620119d59c67b1c7664b6b02939f Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Mon, 27 Jan 2025 13:56:16 +0800 Subject: [PATCH 109/136] Revert "bump ci" This reverts commit 5ed169cb87c0dbf7d50ec2ca2071fc9d0ee274ff. --- python/sglang/srt/managers/generation_manager.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/managers/generation_manager.py b/python/sglang/srt/managers/generation_manager.py index d781daf09ea..18686424478 100644 --- a/python/sglang/srt/managers/generation_manager.py +++ b/python/sglang/srt/managers/generation_manager.py @@ -655,7 +655,7 @@ def background_task(): @dataclasses.dataclass class _ReqState: - """Store the state of a request.""" + """Store the state a request.""" out_list: List finished: bool From 570d65744058056a90ecf2053302c23e27dfe27c Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Mon, 27 Jan 2025 14:50:14 +0800 Subject: [PATCH 110/136] Revert "Revert "bump ci"" This reverts commit 3d46e2f337e5620119d59c67b1c7664b6b02939f. --- python/sglang/srt/managers/generation_manager.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/managers/generation_manager.py b/python/sglang/srt/managers/generation_manager.py index 18686424478..d781daf09ea 100644 --- a/python/sglang/srt/managers/generation_manager.py +++ b/python/sglang/srt/managers/generation_manager.py @@ -655,7 +655,7 @@ def background_task(): @dataclasses.dataclass class _ReqState: - """Store the state a request.""" + """Store the state of a request.""" out_list: List finished: bool From 89e433eeec2828acd96ecc0f6b80508bf99ba533 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Mon, 27 Jan 2025 16:25:51 +0800 Subject: [PATCH 111/136] Revert "Revert "Revert "bump ci""" This reverts commit 570d65744058056a90ecf2053302c23e27dfe27c. 
--- python/sglang/srt/managers/generation_manager.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/managers/generation_manager.py b/python/sglang/srt/managers/generation_manager.py index d781daf09ea..18686424478 100644 --- a/python/sglang/srt/managers/generation_manager.py +++ b/python/sglang/srt/managers/generation_manager.py @@ -655,7 +655,7 @@ def background_task(): @dataclasses.dataclass class _ReqState: - """Store the state of a request.""" + """Store the state a request.""" out_list: List finished: bool From 528a8344ab048a7b6129f9824850f9da800ed551 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Fri, 21 Feb 2025 16:01:36 +0800 Subject: [PATCH 112/136] merge --- python/sglang/srt/entrypoints/engine.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/entrypoints/engine.py b/python/sglang/srt/entrypoints/engine.py index 13642f580e4..ede57adce6f 100644 --- a/python/sglang/srt/entrypoints/engine.py +++ b/python/sglang/srt/entrypoints/engine.py @@ -463,5 +463,5 @@ def _launch_subprocesses(server_args: ServerArgs) -> Tuple[TokenizerManager, Dic # Assume all schedulers have the same scheduler_info scheduler_info = scheduler_infos[0] - tokenizer_manager.configure_max_req_input_len(scheduler_info["max_req_input_len"]) + tokenizer_manager.max_req_input_len = scheduler_info["max_req_input_len"] return tokenizer_manager, scheduler_info From 38f4f65aa5536beaa6e7d2dfbd268811c11c49f7 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Fri, 21 Feb 2025 16:01:55 +0800 Subject: [PATCH 113/136] Revert "merge" This reverts commit 528a8344ab048a7b6129f9824850f9da800ed551. --- python/sglang/srt/entrypoints/engine.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/entrypoints/engine.py b/python/sglang/srt/entrypoints/engine.py index ede57adce6f..13642f580e4 100644 --- a/python/sglang/srt/entrypoints/engine.py +++ b/python/sglang/srt/entrypoints/engine.py @@ -463,5 +463,5 @@ def _launch_subprocesses(server_args: ServerArgs) -> Tuple[TokenizerManager, Dic # Assume all schedulers have the same scheduler_info scheduler_info = scheduler_infos[0] - tokenizer_manager.max_req_input_len = scheduler_info["max_req_input_len"] + tokenizer_manager.configure_max_req_input_len(scheduler_info["max_req_input_len"]) return tokenizer_manager, scheduler_info From f5b5246a7ae1c908a2ac90338a99ad394e51c1c8 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Fri, 21 Feb 2025 16:04:55 +0800 Subject: [PATCH 114/136] merge 3364 --- python/sglang/srt/managers/generation_manager.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/python/sglang/srt/managers/generation_manager.py b/python/sglang/srt/managers/generation_manager.py index 18686424478..41050707e05 100644 --- a/python/sglang/srt/managers/generation_manager.py +++ b/python/sglang/srt/managers/generation_manager.py @@ -468,6 +468,12 @@ def _compute_meta_info(self, index, recv_obj, req_obj): } ) + if ( + hasattr(recv_obj, "output_hidden_states") + and len(recv_obj.output_hidden_states[i]) > 0 + ): + meta_info["hidden_states"] = recv_obj.output_hidden_states[i] + return meta_info def _convert_logprob_style( From 00ee8b28f3b4a3a1c38671a6a5b9f36841a84472 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Fri, 21 Feb 2025 16:28:57 +0800 Subject: [PATCH 115/136] engine.py --- .../sglang/srt/orchestration/std/launcher.py | 29 +++++-------------- 1 file changed, 7 insertions(+), 22 deletions(-) diff --git a/python/sglang/srt/orchestration/std/launcher.py 
b/python/sglang/srt/orchestration/std/launcher.py index d8db9652978..f8ad272ebed 100644 --- a/python/sglang/srt/orchestration/std/launcher.py +++ b/python/sglang/srt/orchestration/std/launcher.py @@ -1,29 +1,14 @@ import asyncio -import atexit -import dataclasses import logging import multiprocessing as mp import os import signal -import threading -from typing import AsyncIterator, Dict, Iterator, List, Optional, Tuple, Union +from typing import Dict, Tuple -import torch import uvloop - from sglang.srt.managers.data_parallel_controller import ( run_data_parallel_controller_process, ) -from sglang.srt.managers.io_struct import ( - EmbeddingReqInput, - GenerateReqInput, - GetWeightsByNameReqInput, - InitWeightsUpdateGroupReqInput, - ReleaseMemoryOccupationReqInput, - ResumeMemoryOccupationReqInput, - UpdateWeightsFromDistributedReqInput, - UpdateWeightsFromTensorReqInput, -) from sglang.srt.openai_api.adapter import load_chat_template_for_openai_api from sglang.srt.orchestration.std.detokenizer import run_detokenizer_process from sglang.srt.orchestration.std.orchestrator import StdOrchestrator @@ -31,7 +16,6 @@ from sglang.srt.server_args import PortArgs, ServerArgs from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter from sglang.srt.utils import ( - MultiprocessingSerializer, assert_pkg_version, configure_logger, kill_process_tree, @@ -41,7 +25,6 @@ set_prometheus_multiproc_dir, set_ulimit, ) -from sglang.version import __version__ logger = logging.getLogger(__name__) asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) @@ -134,7 +117,9 @@ def launch(server_args: ServerArgs) -> Tuple[StdOrchestrator, Dict]: # Launch tokenizer process orchestrator = StdOrchestrator(server_args, port_args) if server_args.chat_template: - load_chat_template_for_openai_api(orchestrator, server_args.chat_template) + load_chat_template_for_openai_api( + orchestrator, server_args.chat_template, server_args.model_path + ) # Wait for the model to finish loading scheduler_infos = [] @@ -165,7 +150,7 @@ def _set_envs_and_config(server_args: ServerArgs): # Set global environments os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" os.environ["NCCL_CUMEM_ENABLE"] = "0" - os.environ["NCCL_NVLS_ENABLE"] = "0" + os.environ["NCCL_NVLS_ENABLE"] = str(int(server_args.enable_nccl_nvls)) os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1" os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "4" @@ -184,8 +169,8 @@ def _set_envs_and_config(server_args: ServerArgs): # Check flashinfer version if server_args.attention_backend == "flashinfer": assert_pkg_version( - "flashinfer", - "0.1.6", + "flashinfer_python", + "0.2.1.post2", "Please uninstall the old version and " "reinstall the latest version by following the instructions " "at https://docs.flashinfer.ai/installation.html.", From 6eb3975d1d91e8b674cf151b248ab0647854e6c1 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Fri, 21 Feb 2025 16:31:31 +0800 Subject: [PATCH 116/136] fmt --- python/sglang/srt/orchestration/std/launcher.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/sglang/srt/orchestration/std/launcher.py b/python/sglang/srt/orchestration/std/launcher.py index f8ad272ebed..343c30ed611 100644 --- a/python/sglang/srt/orchestration/std/launcher.py +++ b/python/sglang/srt/orchestration/std/launcher.py @@ -6,6 +6,7 @@ from typing import Dict, Tuple import uvloop + from sglang.srt.managers.data_parallel_controller import ( run_data_parallel_controller_process, ) From a637e3ea8185195fbd0da7c3198972e52cf2171f Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Fri, 21 Feb 
2025 16:31:39 +0800 Subject: [PATCH 117/136] rename result_queue to match upstream --- python/sglang/srt/managers/scheduler.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index b456fa792bf..998828cc8b3 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -284,7 +284,7 @@ def __init__( self.last_decode_stats_tic = time.time() self.stream_interval = server_args.stream_interval self.current_stream = torch.get_device_module(self.device).current_stream() - self.overlap_result_queue = deque() + self.result_queue = deque() if self.device == "cpu": self.current_stream.synchronize = lambda: None # No-op for CPU @@ -424,7 +424,7 @@ def _process_batch_overlap(self): if batch: result = self.run_batch(batch) - self.overlap_result_queue.append((batch.copy(), result)) + self.result_queue.append((batch.copy(), result)) if self.last_batch is None: # Create a dummy first batch to start the pipeline for overlap schedule. @@ -438,7 +438,7 @@ def _process_batch_overlap(self): if self.last_batch: # Process the results of the last batch - tmp_batch, tmp_result = self.overlap_result_queue.popleft() + tmp_batch, tmp_result = self.result_queue.popleft() tmp_batch.next_batch_sampling_info = ( self.tp_worker.cur_sampling_info if batch else None ) From c7739693d3538f794bb4366b3153a94f819041d8 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Fri, 21 Feb 2025 16:32:52 +0800 Subject: [PATCH 118/136] detokenizer_manager --- python/sglang/srt/managers/detokenizer_manager.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/sglang/srt/managers/detokenizer_manager.py b/python/sglang/srt/managers/detokenizer_manager.py index 980a6eab324..1326e694edc 100644 --- a/python/sglang/srt/managers/detokenizer_manager.py +++ b/python/sglang/srt/managers/detokenizer_manager.py @@ -195,6 +195,7 @@ def handle_batch_token_id_out(self, recv_obj: BatchTokenIDOut): input_top_logprobs_idx=recv_obj.input_top_logprobs_idx, output_top_logprobs_val=recv_obj.output_top_logprobs_val, output_top_logprobs_idx=recv_obj.output_top_logprobs_idx, + output_hidden_states=recv_obj.output_hidden_states, ) From a98540c171d6c5cc2db9ba1c68a72a6bdc622316 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Fri, 21 Feb 2025 16:36:02 +0800 Subject: [PATCH 119/136] engine_base from engine --- python/sglang/srt/entrypoints/engine_base.py | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/entrypoints/engine_base.py b/python/sglang/srt/entrypoints/engine_base.py index d7ba77b183a..7b6eeeac05f 100644 --- a/python/sglang/srt/entrypoints/engine_base.py +++ b/python/sglang/srt/entrypoints/engine_base.py @@ -1,4 +1,4 @@ -from typing import Dict, List, Optional, Union +from typing import Dict, Iterator, List, Optional, Union from sglang.srt.managers.io_struct import GenerateReqInput @@ -12,20 +12,35 @@ def generate( sampling_params: Optional[Union[List[Dict], Dict]] = None, # The token ids for text; one can either specify text or input_ids. input_ids: Optional[Union[List[List[int]], List[int]]] = None, + # The image input. It can be a file name, a url, or base64 encoded string. + # See also python/sglang/srt/utils.py:load_image. 
+ image_data: Optional[Union[List[str], str]] = None, return_logprob: Optional[Union[List[bool], bool]] = False, logprob_start_len: Optional[Union[List[int], int]] = None, top_logprobs_num: Optional[Union[List[int], int]] = None, lora_path: Optional[List[Optional[str]]] = None, + custom_logit_processor: Optional[Union[List[str], str]] = None, stream: bool = False, - ): + ) -> Union[Dict, Iterator[Dict]]: + """ + The arguments of this function is the same as `sglang/srt/managers/io_struct.py::GenerateReqInput`. + Please refer to `GenerateReqInput` for the documentation. + """ + modalities_list = [] + if image_data is not None: + modalities_list.append("image") + obj = GenerateReqInput( text=prompt, input_ids=input_ids, sampling_params=sampling_params, + image_data=image_data, return_logprob=return_logprob, logprob_start_len=logprob_start_len, top_logprobs_num=top_logprobs_num, lora_path=lora_path, + modalities=modalities_list, + custom_logit_processor=custom_logit_processor, stream=stream, ) return self._generate_impl(obj) From 2850970507bd545e140e163f6214f719a8a3e804 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Fri, 21 Feb 2025 16:37:38 +0800 Subject: [PATCH 120/136] fix err --- python/sglang/srt/managers/generation_manager.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/managers/generation_manager.py b/python/sglang/srt/managers/generation_manager.py index 41050707e05..e7e86ac359f 100644 --- a/python/sglang/srt/managers/generation_manager.py +++ b/python/sglang/srt/managers/generation_manager.py @@ -470,9 +470,9 @@ def _compute_meta_info(self, index, recv_obj, req_obj): if ( hasattr(recv_obj, "output_hidden_states") - and len(recv_obj.output_hidden_states[i]) > 0 + and len(recv_obj.output_hidden_states[index]) > 0 ): - meta_info["hidden_states"] = recv_obj.output_hidden_states[i] + meta_info["hidden_states"] = recv_obj.output_hidden_states[index] return meta_info From e452528365ba676f79c8a85df4a82bbbdd4dacb5 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Fri, 21 Feb 2025 17:24:02 +0800 Subject: [PATCH 121/136] bump --- test/srt/test_fragment.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/srt/test_fragment.py b/test/srt/test_fragment.py index 19fb9a185dc..b0884a7da98 100644 --- a/test/srt/test_fragment.py +++ b/test/srt/test_fragment.py @@ -34,7 +34,7 @@ def test_fragment(self): " to spend it outdoors. I decided to take a walk in the nearby park.", " how to improve the performance of my website. I've been doing some research and", " a new user of the platform. I am looking for a new laptop to buy", - " I'm looking for someone to help me with a project.\nI'm a student", + " I'm not sure if you're aware, but I've been trying to get", " the science of numbers and their properties. It is a vast and complex field that", ], ) From 028b54ea1e15d238c532697344f6de48d1f9e5e8 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Fri, 21 Feb 2025 17:31:40 +0800 Subject: [PATCH 122/136] bump --- test/srt/test_fragment.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/srt/test_fragment.py b/test/srt/test_fragment.py index b0884a7da98..5b92e4ce43e 100644 --- a/test/srt/test_fragment.py +++ b/test/srt/test_fragment.py @@ -35,7 +35,7 @@ def test_fragment(self): " how to improve the performance of my website. I've been doing some research and", " a new user of the platform. I am looking for a new laptop to buy", " I'm not sure if you're aware, but I've been trying to get", - " the science of numbers and their properties. 
It is a vast and complex field that", + " the science of numbers and their properties. It is a branch of mathematics that deals", ], ) From 12a76dbbf44bb4bf4047bfd0c9e2b98f662fa2be Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Fri, 21 Feb 2025 22:11:45 +0800 Subject: [PATCH 123/136] more --- python/sglang/srt/managers/generation_manager.py | 3 +-- python/sglang/srt/orchestration/spmd/orchestrator.py | 7 ++++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/python/sglang/srt/managers/generation_manager.py b/python/sglang/srt/managers/generation_manager.py index e7e86ac359f..4cb68a69023 100644 --- a/python/sglang/srt/managers/generation_manager.py +++ b/python/sglang/srt/managers/generation_manager.py @@ -407,9 +407,8 @@ async def tokenize_request( def tokenize_requests( self, - obj: Union[GenerateReqInput, EmbeddingReqInput], + objs: List[Union[GenerateReqInput, EmbeddingReqInput]], ) -> List[Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput]]: - objs = [obj[i] for i in range(obj.batch_size)] loop = asyncio.get_event_loop() return loop.run_until_complete( asyncio.gather(*(self.tokenize_request(obj) for obj in objs)) diff --git a/python/sglang/srt/orchestration/spmd/orchestrator.py b/python/sglang/srt/orchestration/spmd/orchestrator.py index 5ae15688726..fd28bbff4cd 100644 --- a/python/sglang/srt/orchestration/spmd/orchestrator.py +++ b/python/sglang/srt/orchestration/spmd/orchestrator.py @@ -27,7 +27,8 @@ def __init__( def generate(self, obj: GenerateReqInput): obj.normalize_batch_and_arguments() - tokenized_requests = self._generation_converter.tokenize_requests(obj) + objs = [obj] if obj.is_single else [obj[i] for i in range(obj.batch_size)] + tokenized_requests = self._generation_converter.tokenize_requests(objs) rid_to_req_index = {r.rid: i for i, r in enumerate(tokenized_requests)} outputs: List[Dict[str, Any]] = [None] * obj.batch_size @@ -39,7 +40,7 @@ def _handle_scheduler_output(batch_token_id_out: BatchTokenIDOut): for output_index in range(len(batch_str_out.rids)): req_index = rid_to_req_index[batch_str_out.rids[output_index]] outputs[req_index] = self._generation_converter.postprocess_response( - batch_str_out, index=output_index, req_obj=obj[req_index] + batch_str_out, index=output_index, req_obj=objs[req_index] ) self._scheduler.on_generation_output = _handle_scheduler_output @@ -50,7 +51,7 @@ def _handle_scheduler_output(batch_token_id_out: BatchTokenIDOut): while self._scheduler.process_batch(): pass - return outputs + return outputs[0] if obj.is_single else outputs def shutdown(self): self._scheduler.shutdown() From 155825295047944d150ec630a8d761b0059e0f26 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sat, 22 Feb 2025 12:00:32 +0800 Subject: [PATCH 124/136] fmt --- python/sglang/srt/openai_api/adapter.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/python/sglang/srt/openai_api/adapter.py b/python/sglang/srt/openai_api/adapter.py index c589fdf6973..cf257e9c372 100644 --- a/python/sglang/srt/openai_api/adapter.py +++ b/python/sglang/srt/openai_api/adapter.py @@ -938,10 +938,7 @@ def v1_chat_generate_request( if assistant_prefix: encoded = orchestrator.tokenizer.encode(assistant_prefix) - if ( - encoded - and encoded[0] == orchestrator.tokenizer.bos_token_id - ): + if encoded and encoded[0] == orchestrator.tokenizer.bos_token_id: encoded = encoded[1:] prompt_ids += encoded stop = request.stop From bc7f566bec4be498b130fe95c574d2a5755cf849 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sun, 23 Feb 2025 15:51:59 +0800 Subject: [PATCH 
125/136] lint --- python/sglang/srt/orchestration/std/orchestrator.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/sglang/srt/orchestration/std/orchestrator.py b/python/sglang/srt/orchestration/std/orchestrator.py index 7ec0f761670..cd4a7633c5d 100644 --- a/python/sglang/srt/orchestration/std/orchestrator.py +++ b/python/sglang/srt/orchestration/std/orchestrator.py @@ -26,6 +26,7 @@ import zmq import zmq.asyncio from fastapi import BackgroundTasks + from sglang.srt.aio_rwlock import RWLock from sglang.srt.managers.generation_manager import GenerationManager from sglang.srt.managers.io_struct import ( From 3e45ea118754c1b58e7065d70b08fb50df39a094 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sun, 23 Feb 2025 16:01:32 +0800 Subject: [PATCH 126/136] add copyright text --- python/sglang/srt/orchestration/std/detokenizer.py | 13 +++++++++++++ python/sglang/srt/orchestration/std/launcher.py | 13 +++++++++++++ python/sglang/srt/orchestration/std/scheduler.py | 13 +++++++++++++ 3 files changed, 39 insertions(+) diff --git a/python/sglang/srt/orchestration/std/detokenizer.py b/python/sglang/srt/orchestration/std/detokenizer.py index 6673280e9ba..443962b9103 100644 --- a/python/sglang/srt/orchestration/std/detokenizer.py +++ b/python/sglang/srt/orchestration/std/detokenizer.py @@ -1,3 +1,16 @@ +# Copyright 2023-2024 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== import logging import signal diff --git a/python/sglang/srt/orchestration/std/launcher.py b/python/sglang/srt/orchestration/std/launcher.py index 343c30ed611..d422eb01906 100644 --- a/python/sglang/srt/orchestration/std/launcher.py +++ b/python/sglang/srt/orchestration/std/launcher.py @@ -1,3 +1,16 @@ +# Copyright 2023-2024 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== import asyncio import logging import multiprocessing as mp diff --git a/python/sglang/srt/orchestration/std/scheduler.py b/python/sglang/srt/orchestration/std/scheduler.py index d4623d7cb14..5552ceb2b46 100644 --- a/python/sglang/srt/orchestration/std/scheduler.py +++ b/python/sglang/srt/orchestration/std/scheduler.py @@ -1,3 +1,16 @@ +# Copyright 2023-2024 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== import faulthandler import logging import os From 6795a90885e9b915b61cd99d15b8979007a73e8c Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Mon, 24 Feb 2025 10:27:26 +0800 Subject: [PATCH 127/136] mv no_grad --- python/sglang/srt/managers/scheduler.py | 2 -- python/sglang/srt/orchestration/std/scheduler.py | 8 +++++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 998828cc8b3..8e2b3f77350 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -402,7 +402,6 @@ def process_batch(self): else: return self._process_batch_normal() - @torch.no_grad() def _process_batch_normal(self): batch = self.get_next_batch_to_run() self.cur_batch = batch @@ -417,7 +416,6 @@ def _process_batch_normal(self): self.last_batch = batch - @torch.no_grad() def _process_batch_overlap(self): batch = self.get_next_batch_to_run() self.cur_batch = batch diff --git a/python/sglang/srt/orchestration/std/scheduler.py b/python/sglang/srt/orchestration/std/scheduler.py index 5552ceb2b46..cd328b5f2c5 100644 --- a/python/sglang/srt/orchestration/std/scheduler.py +++ b/python/sglang/srt/orchestration/std/scheduler.py @@ -20,6 +20,7 @@ import psutil import setproctitle +import torch import zmq from sglang.srt.managers.io_struct import ( @@ -233,9 +234,10 @@ def run_scheduler_process( } ) - while True: - communicator.recv_and_process_input_requests() - scheduler.process_batch() + with torch.no_grad(): + while True: + communicator.recv_and_process_input_requests() + scheduler.process_batch() except Exception: traceback = get_exception_traceback() logger.error(f"Scheduler hit an exception: {traceback}") From 7bd10eeb9cb18a8bfdb5bbc73bd945e37c1d9763 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Mon, 24 Feb 2025 10:30:06 +0800 Subject: [PATCH 128/136] comments --- python/sglang/srt/orchestration/std/launcher.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/orchestration/std/launcher.py b/python/sglang/srt/orchestration/std/launcher.py index d422eb01906..d963afe7576 100644 --- a/python/sglang/srt/orchestration/std/launcher.py +++ b/python/sglang/srt/orchestration/std/launcher.py @@ -128,7 +128,7 @@ def launch(server_args: ServerArgs) -> Tuple[StdOrchestrator, Dict]: ) detoken_proc.start() - # Launch tokenizer process + # Launch orchestrator process orchestrator = StdOrchestrator(server_args, port_args) if server_args.chat_template: load_chat_template_for_openai_api( From 6eb99b92514d840d1db471e7d06d043f4f52463d Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Mon, 24 Feb 2025 13:09:33 +0800 Subject: [PATCH 129/136] doc --- python/sglang/srt/managers/scheduler.py | 70 +++++++++++++------------ 1 file changed, 36 insertions(+), 34 deletions(-) diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 8e2b3f77350..8eb22082026 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -27,7 
+27,6 @@ import psutil import torch - from sglang.global_config import global_config from sglang.srt.configs.model_config import ModelConfig from sglang.srt.constrained.base_grammar_backend import create_grammar_backend @@ -325,8 +324,8 @@ def __init__( 1.0, ) self.new_token_ratio_decay = ( - self.init_new_token_ratio - self.min_new_token_ratio - ) / global_config.default_new_token_ratio_decay_steps + self.init_new_token_ratio - self.min_new_token_ratio + ) / global_config.default_new_token_ratio_decay_steps self.new_token_ratio = self.init_new_token_ratio # Tells whether the current running batch is full so that we can skip @@ -396,7 +395,8 @@ def watchdog_thread(self): time.sleep(5) self.parent_process.send_signal(signal.SIGQUIT) - def process_batch(self): + def process_batch(self) -> bool: + """Processes a batch and returns whether it has successfully run a batch.""" if self.enable_overlap: return self._process_batch_overlap() else: @@ -415,6 +415,7 @@ def _process_batch_normal(self): self.new_token_ratio = self.init_new_token_ratio self.last_batch = batch + return batch is not None def _process_batch_overlap(self): batch = self.get_next_batch_to_run() @@ -447,6 +448,7 @@ def _process_batch_overlap(self): self.new_token_ratio = self.init_new_token_ratio self.last_batch = batch + return batch is not None def handle_generate_request( self, @@ -619,9 +621,9 @@ def log_prefill_stats( has_being_chunked: bool, ): self.tree_cache_metrics["total"] += ( - adder.log_input_tokens + adder.log_hit_tokens - ) / 10**9 - self.tree_cache_metrics["hit"] += (adder.log_hit_tokens) / 10**9 + adder.log_input_tokens + adder.log_hit_tokens + ) / 10 ** 9 + self.tree_cache_metrics["hit"] += (adder.log_hit_tokens) / 10 ** 9 tree_cache_hit_rate = ( self.tree_cache_metrics["hit"] / self.tree_cache_metrics["total"] ) @@ -806,10 +808,10 @@ def get_new_batch_prefill(self) -> Optional[ScheduleBatch]: if ( self.lora_paths and len( - lora_set - | set([req.lora_path for req in adder.can_run_list]) - | set([req.lora_path]) - ) + lora_set + | set([req.lora_path for req in adder.can_run_list]) + | set([req.lora_path]) + ) > self.max_loras_per_batch ): self.batch_is_full = True @@ -1035,7 +1037,7 @@ def process_batch_result_prefill( if self.is_mixed_chunk and self.enable_overlap and req.finished(): # Free the one delayed token for the mixed decode batch j = len(batch.out_cache_loc) - len(batch.reqs) + i - self.token_to_kv_pool.free(batch.out_cache_loc[j : j + 1]) + self.token_to_kv_pool.free(batch.out_cache_loc[j: j + 1]) continue if req.is_being_chunked <= 0: @@ -1058,10 +1060,10 @@ def process_batch_result_prefill( ): req.hidden_states.append( logits_output.hidden_states[ - hidden_state_offset : ( - hidden_state_offset := hidden_state_offset - + len(req.origin_input_ids) - ) + hidden_state_offset: ( + hidden_state_offset := hidden_state_offset + + len(req.origin_input_ids) + ) ] .cpu() .clone() @@ -1137,7 +1139,7 @@ def process_batch_result_decode( if self.enable_overlap and req.finished(): # Free the one delayed token - self.token_to_kv_pool.free(batch.out_cache_loc[i : i + 1]) + self.token_to_kv_pool.free(batch.out_cache_loc[i: i + 1]) continue if batch.spec_algorithm.is_none(): @@ -1203,15 +1205,15 @@ def add_logprob_return_values( if req.input_token_logprobs_val is None: input_token_logprobs_val = output.input_token_logprobs[ - pt : pt + num_input_logprobs - 1 - req.last_update_decode_tokens - ] + pt: pt + num_input_logprobs - 1 - req.last_update_decode_tokens + ] input_token_logprobs_idx = req.fill_ids[ - 
len(req.fill_ids) - - num_input_logprobs - + 1 : len(req.fill_ids) - - req.last_update_decode_tokens - ] + len(req.fill_ids) + - num_input_logprobs + + 1: len(req.fill_ids) + - req.last_update_decode_tokens + ] # Clip the padded hash values from image tokens. # Otherwise, it will lead to detokenization errors. input_token_logprobs_idx = [ @@ -1232,18 +1234,18 @@ def add_logprob_return_values( # Some decode tokens are re-computed in an extend batch req.output_token_logprobs_val.extend( output.input_token_logprobs[ - pt - + num_input_logprobs - - 1 - - req.last_update_decode_tokens : pt - + num_input_logprobs - - 1 + pt + + num_input_logprobs + - 1 + - req.last_update_decode_tokens: pt + + num_input_logprobs + - 1 ], ) req.output_token_logprobs_idx.extend( req.fill_ids[ - len(req.fill_ids) - - req.last_update_decode_tokens : len(req.fill_ids) + len(req.fill_ids) + - req.last_update_decode_tokens: len(req.fill_ids) ] ) @@ -1257,10 +1259,10 @@ def add_logprob_return_values( if req.last_update_decode_tokens != 0: req.output_top_logprobs_val.extend( - output.input_top_logprobs_val[i][-req.last_update_decode_tokens :] + output.input_top_logprobs_val[i][-req.last_update_decode_tokens:] ) req.output_top_logprobs_idx.extend( - output.input_top_logprobs_idx[i][-req.last_update_decode_tokens :] + output.input_top_logprobs_idx[i][-req.last_update_decode_tokens:] ) req.output_top_logprobs_val.append(output.next_token_top_logprobs_val[i]) From a2c1d941bc1c237055126f4c25092fd6f517fc06 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Mon, 24 Feb 2025 13:10:14 +0800 Subject: [PATCH 130/136] fmt --- python/sglang/srt/managers/scheduler.py | 65 +++++++++++++------------ 1 file changed, 33 insertions(+), 32 deletions(-) diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 8eb22082026..fa3b270a706 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -27,6 +27,7 @@ import psutil import torch + from sglang.global_config import global_config from sglang.srt.configs.model_config import ModelConfig from sglang.srt.constrained.base_grammar_backend import create_grammar_backend @@ -324,8 +325,8 @@ def __init__( 1.0, ) self.new_token_ratio_decay = ( - self.init_new_token_ratio - self.min_new_token_ratio - ) / global_config.default_new_token_ratio_decay_steps + self.init_new_token_ratio - self.min_new_token_ratio + ) / global_config.default_new_token_ratio_decay_steps self.new_token_ratio = self.init_new_token_ratio # Tells whether the current running batch is full so that we can skip @@ -621,9 +622,9 @@ def log_prefill_stats( has_being_chunked: bool, ): self.tree_cache_metrics["total"] += ( - adder.log_input_tokens + adder.log_hit_tokens - ) / 10 ** 9 - self.tree_cache_metrics["hit"] += (adder.log_hit_tokens) / 10 ** 9 + adder.log_input_tokens + adder.log_hit_tokens + ) / 10**9 + self.tree_cache_metrics["hit"] += (adder.log_hit_tokens) / 10**9 tree_cache_hit_rate = ( self.tree_cache_metrics["hit"] / self.tree_cache_metrics["total"] ) @@ -808,10 +809,10 @@ def get_new_batch_prefill(self) -> Optional[ScheduleBatch]: if ( self.lora_paths and len( - lora_set - | set([req.lora_path for req in adder.can_run_list]) - | set([req.lora_path]) - ) + lora_set + | set([req.lora_path for req in adder.can_run_list]) + | set([req.lora_path]) + ) > self.max_loras_per_batch ): self.batch_is_full = True @@ -1037,7 +1038,7 @@ def process_batch_result_prefill( if self.is_mixed_chunk and self.enable_overlap and req.finished(): # Free the one delayed 
token for the mixed decode batch j = len(batch.out_cache_loc) - len(batch.reqs) + i - self.token_to_kv_pool.free(batch.out_cache_loc[j: j + 1]) + self.token_to_kv_pool.free(batch.out_cache_loc[j : j + 1]) continue if req.is_being_chunked <= 0: @@ -1060,10 +1061,10 @@ def process_batch_result_prefill( ): req.hidden_states.append( logits_output.hidden_states[ - hidden_state_offset: ( - hidden_state_offset := hidden_state_offset - + len(req.origin_input_ids) - ) + hidden_state_offset : ( + hidden_state_offset := hidden_state_offset + + len(req.origin_input_ids) + ) ] .cpu() .clone() @@ -1139,7 +1140,7 @@ def process_batch_result_decode( if self.enable_overlap and req.finished(): # Free the one delayed token - self.token_to_kv_pool.free(batch.out_cache_loc[i: i + 1]) + self.token_to_kv_pool.free(batch.out_cache_loc[i : i + 1]) continue if batch.spec_algorithm.is_none(): @@ -1205,15 +1206,15 @@ def add_logprob_return_values( if req.input_token_logprobs_val is None: input_token_logprobs_val = output.input_token_logprobs[ - pt: pt + num_input_logprobs - 1 - req.last_update_decode_tokens - ] + pt : pt + num_input_logprobs - 1 - req.last_update_decode_tokens + ] input_token_logprobs_idx = req.fill_ids[ - len(req.fill_ids) - - num_input_logprobs - + 1: len(req.fill_ids) - - req.last_update_decode_tokens - ] + len(req.fill_ids) + - num_input_logprobs + + 1 : len(req.fill_ids) + - req.last_update_decode_tokens + ] # Clip the padded hash values from image tokens. # Otherwise, it will lead to detokenization errors. input_token_logprobs_idx = [ @@ -1234,18 +1235,18 @@ def add_logprob_return_values( # Some decode tokens are re-computed in an extend batch req.output_token_logprobs_val.extend( output.input_token_logprobs[ - pt - + num_input_logprobs - - 1 - - req.last_update_decode_tokens: pt - + num_input_logprobs - - 1 + pt + + num_input_logprobs + - 1 + - req.last_update_decode_tokens : pt + + num_input_logprobs + - 1 ], ) req.output_token_logprobs_idx.extend( req.fill_ids[ - len(req.fill_ids) - - req.last_update_decode_tokens: len(req.fill_ids) + len(req.fill_ids) + - req.last_update_decode_tokens : len(req.fill_ids) ] ) @@ -1259,10 +1260,10 @@ def add_logprob_return_values( if req.last_update_decode_tokens != 0: req.output_top_logprobs_val.extend( - output.input_top_logprobs_val[i][-req.last_update_decode_tokens:] + output.input_top_logprobs_val[i][-req.last_update_decode_tokens :] ) req.output_top_logprobs_idx.extend( - output.input_top_logprobs_idx[i][-req.last_update_decode_tokens:] + output.input_top_logprobs_idx[i][-req.last_update_decode_tokens :] ) req.output_top_logprobs_val.append(output.next_token_top_logprobs_val[i]) From c531e8cfa59e3e6b18d8d67255f7555f5e962548 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Mon, 24 Feb 2025 13:12:14 +0800 Subject: [PATCH 131/136] comments --- python/sglang/srt/entrypoints/engine_base.py | 13 +++++++++++++ python/sglang/srt/entrypoints/engine_fragment.py | 13 +++++++++++++ .../sglang/srt/orchestration/spmd/orchestrator.py | 13 +++++++++++++ 3 files changed, 39 insertions(+) diff --git a/python/sglang/srt/entrypoints/engine_base.py b/python/sglang/srt/entrypoints/engine_base.py index 7b6eeeac05f..a42a5b7ace2 100644 --- a/python/sglang/srt/entrypoints/engine_base.py +++ b/python/sglang/srt/entrypoints/engine_base.py @@ -1,3 +1,16 @@ +# Copyright 2023-2024 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== from typing import Dict, Iterator, List, Optional, Union from sglang.srt.managers.io_struct import GenerateReqInput diff --git a/python/sglang/srt/entrypoints/engine_fragment.py b/python/sglang/srt/entrypoints/engine_fragment.py index 49a5cd4fcd2..26eceb09808 100644 --- a/python/sglang/srt/entrypoints/engine_fragment.py +++ b/python/sglang/srt/entrypoints/engine_fragment.py @@ -1,3 +1,16 @@ +# Copyright 2023-2024 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== from sglang.srt.entrypoints.engine_base import EngineBase from sglang.srt.managers.io_struct import GenerateReqInput from sglang.srt.orchestration.spmd.orchestrator import SpmdOrchestrator diff --git a/python/sglang/srt/orchestration/spmd/orchestrator.py b/python/sglang/srt/orchestration/spmd/orchestrator.py index fd28bbff4cd..fb663b1542a 100644 --- a/python/sglang/srt/orchestration/spmd/orchestrator.py +++ b/python/sglang/srt/orchestration/spmd/orchestrator.py @@ -1,3 +1,16 @@ +# Copyright 2023-2024 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== from typing import Any, Dict, List from sglang.srt.managers.detokenizer_manager import DetokenizerManager From 6cbbb70d1df80ddd3c462e925bcf734cdee1a531 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Mon, 24 Feb 2025 13:13:12 +0800 Subject: [PATCH 132/136] comments --- python/sglang/srt/entrypoints/engine.py | 4 ---- python/sglang/srt/entrypoints/engine_fragment.py | 2 ++ 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/python/sglang/srt/entrypoints/engine.py b/python/sglang/srt/entrypoints/engine.py index 4fed37a1d72..f08c0f9db75 100644 --- a/python/sglang/srt/entrypoints/engine.py +++ b/python/sglang/srt/entrypoints/engine.py @@ -90,10 +90,6 @@ def __init__(self, **kwargs): self.scheduler_info = scheduler_info def _generate_impl(self, obj: GenerateReqInput): - """ - The arguments of this function is the same as `sglang/srt/managers/io_struct.py::GenerateReqInput`. - Please refer to `GenerateReqInput` for the documentation. - """ loop = asyncio.get_event_loop() generator = self.orchestrator.generate_request(obj, None) diff --git a/python/sglang/srt/entrypoints/engine_fragment.py b/python/sglang/srt/entrypoints/engine_fragment.py index 26eceb09808..790ca49a688 100644 --- a/python/sglang/srt/entrypoints/engine_fragment.py +++ b/python/sglang/srt/entrypoints/engine_fragment.py @@ -18,6 +18,8 @@ class EngineFragment(EngineBase): + """TODO: Add docstring to describe it.""" + def __init__( self, nccl_port: int, From 613cd9c9912ca07eb0dcc958af05ebc5fda3736c Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Mon, 24 Feb 2025 13:14:11 +0800 Subject: [PATCH 133/136] base method --- python/sglang/srt/entrypoints/engine_base.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/python/sglang/srt/entrypoints/engine_base.py b/python/sglang/srt/entrypoints/engine_base.py index a42a5b7ace2..5f3e2f2181c 100644 --- a/python/sglang/srt/entrypoints/engine_base.py +++ b/python/sglang/srt/entrypoints/engine_base.py @@ -60,3 +60,6 @@ def generate( def _generate_impl(self, obj: GenerateReqInput): raise NotImplementedError + + def shutdown(self): + raise NotImplementedError From 58bec2544d19bbfc646b264cd04cb244ddcc2051 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Mon, 24 Feb 2025 13:14:35 +0800 Subject: [PATCH 134/136] doc --- python/sglang/srt/orchestration/spmd/orchestrator.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/sglang/srt/orchestration/spmd/orchestrator.py b/python/sglang/srt/orchestration/spmd/orchestrator.py index fb663b1542a..b27d365d72a 100644 --- a/python/sglang/srt/orchestration/spmd/orchestrator.py +++ b/python/sglang/srt/orchestration/spmd/orchestrator.py @@ -21,6 +21,8 @@ class SpmdOrchestrator: + """TODO: Add docstring to describe it.""" + def __init__( self, server_args: ServerArgs, From dccb7ebb6d64b9e0d989726b01bab0ef6299ebfe Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Mon, 24 Feb 2025 13:15:37 +0800 Subject: [PATCH 135/136] shutdown --- python/sglang/srt/managers/scheduler.py | 67 +++++++++++++------------ 1 file changed, 34 insertions(+), 33 deletions(-) diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 6c3cffbb479..cd47f16a008 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -27,7 +27,6 @@ import psutil import torch - from sglang.global_config import global_config from sglang.srt.configs.model_config import ModelConfig from sglang.srt.constrained.base_grammar_backend import 
create_grammar_backend @@ -325,8 +324,8 @@ def __init__( 1.0, ) self.new_token_ratio_decay = ( - self.init_new_token_ratio - self.min_new_token_ratio - ) / global_config.default_new_token_ratio_decay_steps + self.init_new_token_ratio - self.min_new_token_ratio + ) / global_config.default_new_token_ratio_decay_steps self.new_token_ratio = self.init_new_token_ratio # Tells whether the current running batch is full so that we can skip @@ -622,9 +621,9 @@ def log_prefill_stats( has_being_chunked: bool, ): self.tree_cache_metrics["total"] += ( - adder.log_input_tokens + adder.log_hit_tokens - ) / 10**9 - self.tree_cache_metrics["hit"] += (adder.log_hit_tokens) / 10**9 + adder.log_input_tokens + adder.log_hit_tokens + ) / 10 ** 9 + self.tree_cache_metrics["hit"] += (adder.log_hit_tokens) / 10 ** 9 tree_cache_hit_rate = ( self.tree_cache_metrics["hit"] / self.tree_cache_metrics["total"] ) @@ -809,10 +808,10 @@ def get_new_batch_prefill(self) -> Optional[ScheduleBatch]: if ( self.lora_paths and len( - lora_set - | set([req.lora_path for req in adder.can_run_list]) - | set([req.lora_path]) - ) + lora_set + | set([req.lora_path for req in adder.can_run_list]) + | set([req.lora_path]) + ) > self.max_loras_per_batch ): self.batch_is_full = True @@ -1038,7 +1037,7 @@ def process_batch_result_prefill( if self.is_mixed_chunk and self.enable_overlap and req.finished(): # Free the one delayed token for the mixed decode batch j = len(batch.out_cache_loc) - len(batch.reqs) + i - self.token_to_kv_pool.free(batch.out_cache_loc[j : j + 1]) + self.token_to_kv_pool.free(batch.out_cache_loc[j: j + 1]) continue if req.is_being_chunked <= 0: @@ -1061,10 +1060,10 @@ def process_batch_result_prefill( ): req.hidden_states.append( logits_output.hidden_states[ - hidden_state_offset : ( - hidden_state_offset := hidden_state_offset - + len(req.origin_input_ids) - ) + hidden_state_offset: ( + hidden_state_offset := hidden_state_offset + + len(req.origin_input_ids) + ) ] .cpu() .clone() @@ -1140,7 +1139,7 @@ def process_batch_result_decode( if self.enable_overlap and req.finished(): # Free the one delayed token - self.token_to_kv_pool.free(batch.out_cache_loc[i : i + 1]) + self.token_to_kv_pool.free(batch.out_cache_loc[i: i + 1]) continue if batch.spec_algorithm.is_none(): @@ -1206,15 +1205,15 @@ def add_logprob_return_values( if req.input_token_logprobs_val is None: input_token_logprobs_val = output.input_token_logprobs[ - pt : pt + num_input_logprobs - 1 - req.last_update_decode_tokens - ] + pt: pt + num_input_logprobs - 1 - req.last_update_decode_tokens + ] input_token_logprobs_idx = req.fill_ids[ - len(req.fill_ids) - - num_input_logprobs - + 1 : len(req.fill_ids) - - req.last_update_decode_tokens - ] + len(req.fill_ids) + - num_input_logprobs + + 1: len(req.fill_ids) + - req.last_update_decode_tokens + ] # Clip the padded hash values from image tokens. # Otherwise, it will lead to detokenization errors. 
input_token_logprobs_idx = [ @@ -1235,18 +1234,18 @@ def add_logprob_return_values( # Some decode tokens are re-computed in an extend batch req.output_token_logprobs_val.extend( output.input_token_logprobs[ - pt - + num_input_logprobs - - 1 - - req.last_update_decode_tokens : pt - + num_input_logprobs - - 1 + pt + + num_input_logprobs + - 1 + - req.last_update_decode_tokens: pt + + num_input_logprobs + - 1 ], ) req.output_token_logprobs_idx.extend( req.fill_ids[ - len(req.fill_ids) - - req.last_update_decode_tokens : len(req.fill_ids) + len(req.fill_ids) + - req.last_update_decode_tokens: len(req.fill_ids) ] ) @@ -1260,10 +1259,10 @@ def add_logprob_return_values( if req.last_update_decode_tokens != 0: req.output_top_logprobs_val.extend( - output.input_top_logprobs_val[i][-req.last_update_decode_tokens :] + output.input_top_logprobs_val[i][-req.last_update_decode_tokens:] ) req.output_top_logprobs_idx.extend( - output.input_top_logprobs_idx[i][-req.last_update_decode_tokens :] + output.input_top_logprobs_idx[i][-req.last_update_decode_tokens:] ) req.output_top_logprobs_val.append(output.next_token_top_logprobs_val[i]) @@ -1639,6 +1638,8 @@ def close_session(self, recv_req: CloseSessionReqInput): del self.sessions[session_id] def shutdown(self): + if self.draft_worker: + self.draft_worker.shutdown() self.tp_worker.shutdown() From 17a88740995a6fda91f99ee7a04baa7f918ff0eb Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Mon, 24 Feb 2025 13:34:50 +0800 Subject: [PATCH 136/136] fmt --- python/sglang/srt/managers/scheduler.py | 65 +++++++++++++------------ 1 file changed, 33 insertions(+), 32 deletions(-) diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index cd47f16a008..6c71304020e 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -27,6 +27,7 @@ import psutil import torch + from sglang.global_config import global_config from sglang.srt.configs.model_config import ModelConfig from sglang.srt.constrained.base_grammar_backend import create_grammar_backend @@ -324,8 +325,8 @@ def __init__( 1.0, ) self.new_token_ratio_decay = ( - self.init_new_token_ratio - self.min_new_token_ratio - ) / global_config.default_new_token_ratio_decay_steps + self.init_new_token_ratio - self.min_new_token_ratio + ) / global_config.default_new_token_ratio_decay_steps self.new_token_ratio = self.init_new_token_ratio # Tells whether the current running batch is full so that we can skip @@ -621,9 +622,9 @@ def log_prefill_stats( has_being_chunked: bool, ): self.tree_cache_metrics["total"] += ( - adder.log_input_tokens + adder.log_hit_tokens - ) / 10 ** 9 - self.tree_cache_metrics["hit"] += (adder.log_hit_tokens) / 10 ** 9 + adder.log_input_tokens + adder.log_hit_tokens + ) / 10**9 + self.tree_cache_metrics["hit"] += (adder.log_hit_tokens) / 10**9 tree_cache_hit_rate = ( self.tree_cache_metrics["hit"] / self.tree_cache_metrics["total"] ) @@ -808,10 +809,10 @@ def get_new_batch_prefill(self) -> Optional[ScheduleBatch]: if ( self.lora_paths and len( - lora_set - | set([req.lora_path for req in adder.can_run_list]) - | set([req.lora_path]) - ) + lora_set + | set([req.lora_path for req in adder.can_run_list]) + | set([req.lora_path]) + ) > self.max_loras_per_batch ): self.batch_is_full = True @@ -1037,7 +1038,7 @@ def process_batch_result_prefill( if self.is_mixed_chunk and self.enable_overlap and req.finished(): # Free the one delayed token for the mixed decode batch j = len(batch.out_cache_loc) - len(batch.reqs) + i - 
self.token_to_kv_pool.free(batch.out_cache_loc[j: j + 1]) + self.token_to_kv_pool.free(batch.out_cache_loc[j : j + 1]) continue if req.is_being_chunked <= 0: @@ -1060,10 +1061,10 @@ def process_batch_result_prefill( ): req.hidden_states.append( logits_output.hidden_states[ - hidden_state_offset: ( - hidden_state_offset := hidden_state_offset - + len(req.origin_input_ids) - ) + hidden_state_offset : ( + hidden_state_offset := hidden_state_offset + + len(req.origin_input_ids) + ) ] .cpu() .clone() @@ -1139,7 +1140,7 @@ def process_batch_result_decode( if self.enable_overlap and req.finished(): # Free the one delayed token - self.token_to_kv_pool.free(batch.out_cache_loc[i: i + 1]) + self.token_to_kv_pool.free(batch.out_cache_loc[i : i + 1]) continue if batch.spec_algorithm.is_none(): @@ -1205,15 +1206,15 @@ def add_logprob_return_values( if req.input_token_logprobs_val is None: input_token_logprobs_val = output.input_token_logprobs[ - pt: pt + num_input_logprobs - 1 - req.last_update_decode_tokens - ] + pt : pt + num_input_logprobs - 1 - req.last_update_decode_tokens + ] input_token_logprobs_idx = req.fill_ids[ - len(req.fill_ids) - - num_input_logprobs - + 1: len(req.fill_ids) - - req.last_update_decode_tokens - ] + len(req.fill_ids) + - num_input_logprobs + + 1 : len(req.fill_ids) + - req.last_update_decode_tokens + ] # Clip the padded hash values from image tokens. # Otherwise, it will lead to detokenization errors. input_token_logprobs_idx = [ @@ -1234,18 +1235,18 @@ def add_logprob_return_values( # Some decode tokens are re-computed in an extend batch req.output_token_logprobs_val.extend( output.input_token_logprobs[ - pt - + num_input_logprobs - - 1 - - req.last_update_decode_tokens: pt - + num_input_logprobs - - 1 + pt + + num_input_logprobs + - 1 + - req.last_update_decode_tokens : pt + + num_input_logprobs + - 1 ], ) req.output_token_logprobs_idx.extend( req.fill_ids[ - len(req.fill_ids) - - req.last_update_decode_tokens: len(req.fill_ids) + len(req.fill_ids) + - req.last_update_decode_tokens : len(req.fill_ids) ] ) @@ -1259,10 +1260,10 @@ def add_logprob_return_values( if req.last_update_decode_tokens != 0: req.output_top_logprobs_val.extend( - output.input_top_logprobs_val[i][-req.last_update_decode_tokens:] + output.input_top_logprobs_val[i][-req.last_update_decode_tokens :] ) req.output_top_logprobs_idx.extend( - output.input_top_logprobs_idx[i][-req.last_update_decode_tokens:] + output.input_top_logprobs_idx[i][-req.last_update_decode_tokens :] ) req.output_top_logprobs_val.append(output.next_token_top_logprobs_val[i])