Revert "Rename TokenizerManager to StdOrchestrator" (#3828)
merrymercy authored Feb 24, 2025
1 parent c9745ee commit f2388f6
Showing 11 changed files with 130 additions and 116 deletions.
2 changes: 1 addition & 1 deletion docs/backend/function_calling.ipynb
@@ -426,7 +426,7 @@
"from sglang.srt.managers.io_struct import Tool, Function\n",
"\n",
"llm = sgl.Engine(model_path=\"meta-llama/Meta-Llama-3.1-8B-Instruct\")\n",
"tokenizer = llm.orchestrator.tokenizer\n",
"tokenizer = llm.tokenizer_manager.tokenizer\n",
"input_ids = tokenizer.apply_chat_template(\n",
" messages, tokenize=True, add_generation_prompt=True, tools=tools\n",
")\n",
50 changes: 27 additions & 23 deletions python/sglang/srt/entrypoints/engine.py
@@ -48,8 +48,8 @@
UpdateWeightsFromTensorReqInput,
)
from sglang.srt.managers.scheduler import run_scheduler_process
+from sglang.srt.managers.tokenizer_manager import TokenizerManager
from sglang.srt.openai_api.adapter import load_chat_template_for_openai_api
-from sglang.srt.orchestration.std.orchestrator import StdOrchestrator
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
from sglang.srt.utils import (
@@ -74,12 +74,12 @@ class Engine:
The entry point to the inference engine.
- The engine consists of three components:
-1. StdOrchestrator: Tokenizes the requests and sends them to the scheduler.
+1. TokenizerManager: Tokenizes the requests and sends them to the scheduler.
2. Scheduler (subprocess): Receives requests from the Tokenizer Manager, schedules batches, forwards them, and sends the output tokens to the Detokenizer Manager.
3. DetokenizerManager (subprocess): Detokenizes the output tokens and sends the result back to the Tokenizer Manager.
Note:
-1. The HTTP server, Engine, and StdOrchestrator both run in the main process.
+1. The HTTP server, Engine, and TokenizerManager both run in the main process.
2. Inter-process communication is done through ICP (each process uses a different port) via the ZMQ library.
"""

@@ -102,8 +102,10 @@ def __init__(self, **kwargs):
atexit.register(self.shutdown)

# Launch subprocesses
-orchestrator, scheduler_info = _launch_subprocesses(server_args=server_args)
-self.orchestrator = orchestrator
+tokenizer_manager, scheduler_info = _launch_subprocesses(
+    server_args=server_args
+)
+self.tokenizer_manager = tokenizer_manager
self.scheduler_info = scheduler_info

def generate(
@@ -145,7 +147,7 @@ def generate(
stream=stream,
)
loop = asyncio.get_event_loop()
-generator = self.orchestrator.generate_request(obj, None)
+generator = self.tokenizer_manager.generate_request(obj, None)

if stream:

@@ -195,7 +197,7 @@ async def async_generate(
stream=stream,
custom_logit_processor=custom_logit_processor,
)
-generator = self.orchestrator.generate_request(obj, None)
+generator = self.tokenizer_manager.generate_request(obj, None)

if stream is True:
return generator
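Both generate and async_generate above route a GenerateReqInput through tokenizer_manager.generate_request; only the event-loop handling differs. A short usage sketch of the synchronous path (model path taken from the notebook above; prompt and sampling parameters are illustrative):

import sglang as sgl

llm = sgl.Engine(model_path="meta-llama/Meta-Llama-3.1-8B-Instruct")
out = llm.generate(
    prompt="The capital of France is",
    sampling_params={"temperature": 0, "max_new_tokens": 16},
)
print(out["text"])
llm.shutdown()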
@@ -213,7 +215,7 @@ def encode(

obj = EmbeddingReqInput(text=prompt)
loop = asyncio.get_event_loop()
-generator = self.orchestrator.generate_request(obj, None)
+generator = self.tokenizer_manager.generate_request(obj, None)
ret = loop.run_until_complete(generator.__anext__())
return ret
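encode wraps the same generate_request path with an EmbeddingReqInput and returns the first result. A hedged usage sketch, assuming the engine is launched with an embedding model via is_embedding=True and that the returned dict exposes an "embedding" field (both assumptions, not shown in this diff):

llm = sgl.Engine(model_path="intfloat/e5-mistral-7b-instruct", is_embedding=True)
ret = llm.encode("What is the capital of France?")
print(len(ret["embedding"]))  # dimensionality of the returned vector
llm.shutdown()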

@@ -222,14 +224,14 @@ def shutdown(self):
kill_process_tree(os.getpid(), include_parent=False)

def start_profile(self):
-self.orchestrator.start_profile()
+self.tokenizer_manager.start_profile()

def stop_profile(self):
-self.orchestrator.stop_profile()
+self.tokenizer_manager.stop_profile()

def get_server_info(self):
return {
-**dataclasses.asdict(self.orchestrator.server_args), # server args
+**dataclasses.asdict(self.tokenizer_manager.server_args), # server args
**self.scheduler_info,
"version": __version__,
}
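get_server_info merges the server arguments, the scheduler_info collected at launch, and the package version into one flat dict. A quick sketch reusing the llm engine from the sketch above (only the "version" key appears in this diff; the remaining keys depend on ServerArgs and the scheduler):

info = llm.get_server_info()
print(info["version"])
print(sorted(info.keys())[:5])  # a peek at the merged server-args / scheduler keys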
@@ -254,7 +256,7 @@ def init_weights_update_group(
)
loop = asyncio.get_event_loop()
return loop.run_until_complete(
-self.orchestrator.init_weights_update_group(obj, None)
+self.tokenizer_manager.init_weights_update_group(obj, None)
)

def update_weights_from_distributed(self, name: str, dtype, shape):
@@ -266,7 +268,7 @@ def update_weights_from_distributed(self, name: str, dtype, shape):
)
loop = asyncio.get_event_loop()
return loop.run_until_complete(
-self.orchestrator.update_weights_from_distributed(obj, None)
+self.tokenizer_manager.update_weights_from_distributed(obj, None)
)

def update_weights_from_tensor(self, named_tensors: List[Tuple[str, torch.Tensor]]):
@@ -276,29 +278,31 @@ def update_weights_from_tensor(self, named_tensors: List[Tuple[str, torch.Tensor
)
loop = asyncio.get_event_loop()
return loop.run_until_complete(
-self.orchestrator.update_weights_from_tensor(obj, None)
+self.tokenizer_manager.update_weights_from_tensor(obj, None)
)

def get_weights_by_name(self, name: str, truncate_size: int = 100):
"""Get weights by parameter name."""
obj = GetWeightsByNameReqInput(name=name, truncate_size=truncate_size)
loop = asyncio.get_event_loop()
-return loop.run_until_complete(self.orchestrator.get_weights_by_name(obj, None))
+return loop.run_until_complete(
+    self.tokenizer_manager.get_weights_by_name(obj, None)
+)

def release_memory_occupation(self):
"""Release GPU occupation temporarily."""
obj = ReleaseMemoryOccupationReqInput()
loop = asyncio.get_event_loop()
return loop.run_until_complete(
-self.orchestrator.release_memory_occupation(obj, None)
+self.tokenizer_manager.release_memory_occupation(obj, None)
)

def resume_memory_occupation(self):
"""Resume GPU occupation."""
obj = ResumeMemoryOccupationReqInput()
loop = asyncio.get_event_loop()
return loop.run_until_complete(
-self.orchestrator.resume_memory_occupation(obj, None)
+self.tokenizer_manager.resume_memory_occupation(obj, None)
)
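release_memory_occupation and resume_memory_occupation are the hooks an RLHF-style trainer can use to time-share the GPU with the engine, typically paired with update_weights_from_tensor. A hedged sketch of that flow; it assumes the engine was started with enable_memory_saver=True and that policy_model is the trainer's module (both assumptions, not shown in this diff):

llm = sgl.Engine(
    model_path="meta-llama/Meta-Llama-3.1-8B-Instruct",
    enable_memory_saver=True,  # assumed prerequisite for release/resume
)
llm.release_memory_occupation()   # hand GPU memory back to the trainer
# ... the training framework updates the policy weights here ...
llm.resume_memory_occupation()    # re-occupy GPU memory before serving again
llm.update_weights_from_tensor(
    named_tensors=list(policy_model.named_parameters())  # policy_model: hypothetical
)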


@@ -347,9 +351,9 @@ def sigquit_handler(signum, frame):
mp.set_start_method("spawn", force=True)


-def _launch_subprocesses(server_args: ServerArgs) -> Tuple[StdOrchestrator, Dict]:
+def _launch_subprocesses(server_args: ServerArgs) -> Tuple[TokenizerManager, Dict]:
"""
-Launch the StdOrchestrator in the main process, the Scheduler in a subprocess, and the DetokenizerManager in another subprocess.
+Launch the TokenizerManager in the main process, the Scheduler in a subprocess, and the DetokenizerManager in another subprocess.
"""
# Configure global environment
configure_logger(server_args)
@@ -432,10 +436,10 @@ def _launch_subprocesses(server_args: ServerArgs) -> Tuple[StdOrchestrator, Dict
detoken_proc.start()

# Launch tokenizer process
-orchestrator = StdOrchestrator(server_args, port_args)
+tokenizer_manager = TokenizerManager(server_args, port_args)
if server_args.chat_template:
load_chat_template_for_openai_api(
-orchestrator, server_args.chat_template, server_args.model_path
+tokenizer_manager, server_args.chat_template, server_args.model_path
)

# Wait for the model to finish loading
@@ -459,5 +463,5 @@ def _launch_subprocesses(server_args: ServerArgs) -> Tuple[StdOrchestrator, Dict

# Assume all schedulers have the same scheduler_info
scheduler_info = scheduler_infos[0]
-orchestrator.configure_max_req_input_len(scheduler_info["max_req_input_len"])
-return orchestrator, scheduler_info
+tokenizer_manager.configure_max_req_input_len(scheduler_info["max_req_input_len"])
+return tokenizer_manager, scheduler_info