Commit dccb7eb
shutdown
fzyzcjy committed Feb 24, 2025
1 parent 58bec25 commit dccb7eb
Showing 1 changed file with 34 additions and 33 deletions.
python/sglang/srt/managers/scheduler.py (67 changes: 34 additions & 33 deletions)
@@ -27,7 +27,6 @@

import psutil
import torch
-
from sglang.global_config import global_config
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.constrained.base_grammar_backend import create_grammar_backend
@@ -325,8 +324,8 @@ def __init__(
            1.0,
        )
        self.new_token_ratio_decay = (
-            self.init_new_token_ratio - self.min_new_token_ratio
-        ) / global_config.default_new_token_ratio_decay_steps
+                self.init_new_token_ratio - self.min_new_token_ratio
+        ) / global_config.default_new_token_ratio_decay_steps
self.new_token_ratio = self.init_new_token_ratio

# Tells whether the current running batch is full so that we can skip
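
This hunk only re-indents the wrapped expression; the schedule it touches decays new_token_ratio linearly from init_new_token_ratio down to min_new_token_ratio over default_new_token_ratio_decay_steps. A minimal standalone sketch of that schedule, with illustrative constants standing in for the global_config values:

# Sketch of the new-token-ratio decay schedule; the constants are
# illustrative stand-ins for what sglang reads from global_config.
INIT_NEW_TOKEN_RATIO = 0.7
MIN_NEW_TOKEN_RATIO = 0.1
DECAY_STEPS = 600

decay = (INIT_NEW_TOKEN_RATIO - MIN_NEW_TOKEN_RATIO) / DECAY_STEPS

def new_token_ratio_at(step: int) -> float:
    # Linear decay, clamped at the floor once the decay steps are exhausted.
    return max(INIT_NEW_TOKEN_RATIO - decay * step, MIN_NEW_TOKEN_RATIO)

assert new_token_ratio_at(0) == INIT_NEW_TOKEN_RATIO
assert new_token_ratio_at(10_000) == MIN_NEW_TOKEN_RATIO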
@@ -622,9 +621,9 @@ def log_prefill_stats(
        has_being_chunked: bool,
    ):
        self.tree_cache_metrics["total"] += (
-            adder.log_input_tokens + adder.log_hit_tokens
-        ) / 10**9
-        self.tree_cache_metrics["hit"] += (adder.log_hit_tokens) / 10**9
+            adder.log_input_tokens + adder.log_hit_tokens
+        ) / 10 ** 9
+        self.tree_cache_metrics["hit"] += (adder.log_hit_tokens) / 10 ** 9
tree_cache_hit_rate = (
self.tree_cache_metrics["hit"] / self.tree_cache_metrics["total"]
)
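
Apart from the spacing around **, this hunk leaves the bookkeeping intact: both counters accumulate token counts in units of 10^9 tokens, and the hit rate is their ratio. A self-contained sketch of that calculation, with a plain dict and function standing in for the scheduler's fields:

# Sketch of the tree-cache hit-rate bookkeeping from the hunk above;
# log_input_tokens / log_hit_tokens mirror the adder fields in the diff.
tree_cache_metrics = {"total": 0.0, "hit": 0.0}

def record_prefill(log_input_tokens: int, log_hit_tokens: int) -> float:
    # Both counters are kept in units of 1e9 tokens, as in the diff.
    tree_cache_metrics["total"] += (log_input_tokens + log_hit_tokens) / 10**9
    tree_cache_metrics["hit"] += log_hit_tokens / 10**9
    return tree_cache_metrics["hit"] / tree_cache_metrics["total"]

print(f"hit rate: {record_prefill(900, 100):.2f}")  # hit rate: 0.10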
@@ -809,10 +808,10 @@ def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
            if (
                self.lora_paths
                and len(
-                    lora_set
-                    | set([req.lora_path for req in adder.can_run_list])
-                    | set([req.lora_path])
-                )
+                        lora_set
+                        | set([req.lora_path for req in adder.can_run_list])
+                        | set([req.lora_path])
+                )
> self.max_loras_per_batch
):
self.batch_is_full = True
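
Only the continuation indentation changes here; the condition itself marks the batch as full when the LoRA adapters already scheduled, unioned with those needed by accepted and incoming requests, would exceed max_loras_per_batch. A standalone sketch of that capacity check, using a hypothetical simplified request type:

# Sketch of the LoRA capacity check; Req is a hypothetical stand-in for
# the scheduler's request objects, which carry a lora_path attribute.
from dataclasses import dataclass
from typing import Optional

@dataclass
class Req:
    lora_path: Optional[str]

def would_exceed_lora_budget(
    lora_set: set, can_run_list: list, req: Req, max_loras_per_batch: int
) -> bool:
    # Union of adapters already in use, adapters of accepted requests,
    # and the candidate request's adapter, compared against the budget.
    needed = lora_set | {r.lora_path for r in can_run_list} | {req.lora_path}
    return len(needed) > max_loras_per_batch

print(would_exceed_lora_budget({"a"}, [Req("b")], Req("c"), 2))  # True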
@@ -1038,7 +1037,7 @@ def process_batch_result_prefill(
                if self.is_mixed_chunk and self.enable_overlap and req.finished():
                    # Free the one delayed token for the mixed decode batch
                    j = len(batch.out_cache_loc) - len(batch.reqs) + i
-                    self.token_to_kv_pool.free(batch.out_cache_loc[j : j + 1])
+                    self.token_to_kv_pool.free(batch.out_cache_loc[j: j + 1])
continue

if req.is_being_chunked <= 0:
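
Per the diff's own comment, a finished request under overlap scheduling still owns one KV-cache slot allocated for the token in flight, and this line frees exactly that slot. A toy sketch of the index arithmetic, with a list standing in for batch.out_cache_loc and a free list standing in for the KV pool:

# Toy model of freeing the one delayed KV slot in a mixed chunk batch.
# out_cache_loc holds one slot per decode token at the end of the buffer,
# so request i's delayed slot sits at len(out_cache_loc) - num_reqs + i.
out_cache_loc = [100, 101, 102, 103, 104]  # last 2 entries: decode tokens
num_reqs = 2
freed = []  # stand-in for token_to_kv_pool.free()

i = 1  # the finished request's index in the batch
j = len(out_cache_loc) - num_reqs + i
freed.extend(out_cache_loc[j : j + 1])
print(freed)  # [104]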
Expand All @@ -1061,10 +1060,10 @@ def process_batch_result_prefill(
):
req.hidden_states.append(
logits_output.hidden_states[
hidden_state_offset : (
hidden_state_offset := hidden_state_offset
+ len(req.origin_input_ids)
)
hidden_state_offset: (
hidden_state_offset := hidden_state_offset
+ len(req.origin_input_ids)
)
]
.cpu()
.clone()
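
The reformatted slice uses an assignment expression: the walrus operator advances hidden_state_offset by the request's input length while the slice bound is being evaluated, so consecutive requests consume consecutive rows of the packed hidden-state tensor. A minimal sketch of the same pattern on a plain list, with made-up names:

# Minimal demonstration of the walrus-in-slice pattern from the hunk:
# each request takes its own span of the packed buffer, and the offset
# advances as a side effect of computing the slice's stop bound.
packed = list(range(10))   # stand-in for logits_output.hidden_states
input_lens = [3, 4, 3]     # stand-in for len(req.origin_input_ids)

offset = 0
spans = []
for n in input_lens:
    spans.append(packed[offset : (offset := offset + n)])
print(spans)  # [[0, 1, 2], [3, 4, 5, 6], [7, 8, 9]]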
@@ -1140,7 +1139,7 @@ def process_batch_result_decode(

            if self.enable_overlap and req.finished():
                # Free the one delayed token
-                self.token_to_kv_pool.free(batch.out_cache_loc[i : i + 1])
+                self.token_to_kv_pool.free(batch.out_cache_loc[i: i + 1])
continue

if batch.spec_algorithm.is_none():
@@ -1206,15 +1205,15 @@ def add_logprob_return_values(

        if req.input_token_logprobs_val is None:
            input_token_logprobs_val = output.input_token_logprobs[
-                pt : pt + num_input_logprobs - 1 - req.last_update_decode_tokens
-            ]
+                pt: pt + num_input_logprobs - 1 - req.last_update_decode_tokens
+            ]

            input_token_logprobs_idx = req.fill_ids[
-                len(req.fill_ids)
-                - num_input_logprobs
-                + 1 : len(req.fill_ids)
-                - req.last_update_decode_tokens
-            ]
+                len(req.fill_ids)
+                - num_input_logprobs
+                + 1: len(req.fill_ids)
+                - req.last_update_decode_tokens
+            ]
# Clip the padded hash values from image tokens.
# Otherwise, it will lead to detokenization errors.
input_token_logprobs_idx = [
@@ -1235,18 +1234,18 @@
            # Some decode tokens are re-computed in an extend batch
            req.output_token_logprobs_val.extend(
                output.input_token_logprobs[
-                    pt
-                    + num_input_logprobs
-                    - 1
-                    - req.last_update_decode_tokens : pt
-                    + num_input_logprobs
-                    - 1
+                    pt
+                    + num_input_logprobs
+                    - 1
+                    - req.last_update_decode_tokens: pt
+                    + num_input_logprobs
+                    - 1
                ],
            )
            req.output_token_logprobs_idx.extend(
                req.fill_ids[
-                    len(req.fill_ids)
-                    - req.last_update_decode_tokens : len(req.fill_ids)
+                    len(req.fill_ids)
+                    - req.last_update_decode_tokens: len(req.fill_ids)
]
)
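
These two hunks only reformat slice colons; the arithmetic they wrap recovers, from the flat logprob array of an extend batch, the logprobs of decode tokens that were recomputed. A worked example with small made-up values, where the names mirror the diff:

# Worked example of the slice arithmetic above, with made-up sizes.
# An extend batch recomputes the last `last_update_decode_tokens` decode
# tokens, so their logprobs live at the tail of the input-logprob span.
pt = 0                          # this request's offset in the flat array
num_input_logprobs = 6          # logprobs computed for this request
last_update_decode_tokens = 2   # decode tokens recomputed in this extend
input_token_logprobs = [-0.1, -0.2, -0.3, -0.4, -0.5, -0.6]

start = pt + num_input_logprobs - 1 - last_update_decode_tokens
stop = pt + num_input_logprobs - 1
recomputed = input_token_logprobs[start:stop]
print(recomputed)  # [-0.4, -0.5] -> appended to output_token_logprobs_val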

@@ -1260,10 +1259,10 @@

        if req.last_update_decode_tokens != 0:
            req.output_top_logprobs_val.extend(
-                output.input_top_logprobs_val[i][-req.last_update_decode_tokens :]
+                output.input_top_logprobs_val[i][-req.last_update_decode_tokens:]
            )
            req.output_top_logprobs_idx.extend(
-                output.input_top_logprobs_idx[i][-req.last_update_decode_tokens :]
+                output.input_top_logprobs_idx[i][-req.last_update_decode_tokens:]
)

req.output_top_logprobs_val.append(output.next_token_top_logprobs_val[i])
@@ -1639,6 +1638,8 @@ def close_session(self, recv_req: CloseSessionReqInput):
            del self.sessions[session_id]

    def shutdown(self):
+        if self.draft_worker:
+            self.draft_worker.shutdown()
self.tp_worker.shutdown()
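
This is the substantive change of the commit: the scheduler now shuts the speculative-decoding draft worker down before the tensor-parallel worker, guarding for configurations where no draft worker exists. A minimal runnable sketch of that ordering, with stand-in worker classes in place of sglang's:

# Sketch of the guarded shutdown ordering added by this commit.
# DraftWorker / TpWorker are stand-ins for sglang's worker classes.
class TpWorker:
    def shutdown(self):
        print("tp worker down")

class DraftWorker:
    def shutdown(self):
        print("draft worker down")

class Scheduler:
    def __init__(self, speculative: bool):
        self.tp_worker = TpWorker()
        self.draft_worker = DraftWorker() if speculative else None

    def shutdown(self):
        # Stop the draft worker first, if speculative decoding is enabled.
        if self.draft_worker:
            self.draft_worker.shutdown()
        self.tp_worker.shutdown()

Scheduler(speculative=True).shutdown()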

