Commit

Merge branch 'main' into mmmu
zhaochenyang20 authored Feb 22, 2025
2 parents 15c25d1 + 9087694 commit cc497c9
Showing 9 changed files with 128 additions and 101 deletions.
213 changes: 112 additions & 101 deletions python/sglang/srt/managers/detokenizer_manager.py
@@ -32,7 +32,11 @@
)
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import configure_logger, get_zmq_socket
from sglang.utils import (
    TypeBasedDispatcher,
    find_printable_text,
    get_exception_traceback,
)

logger = logging.getLogger(__name__)

@@ -83,6 +87,13 @@ def __init__(

        self.decode_status = LimitedCapacityDict(capacity=DETOKENIZER_MAX_STATES)

        self._request_dispatcher = TypeBasedDispatcher(
            [
                (BatchEmbeddingOut, self.handle_batch_embedding_out),
                (BatchTokenIDOut, self.handle_batch_token_id_out),
            ]
        )

    def trim_matched_stop(
        self, output: Union[str, List[int]], finished_reason: Dict, no_stop_trim: bool
    ):
@@ -111,115 +122,115 @@ def event_loop(self):

        while True:
            recv_obj = self.recv_from_scheduler.recv_pyobj()
            output = self._request_dispatcher(recv_obj)
            self.send_to_tokenizer.send_pyobj(output)

    def handle_batch_embedding_out(self, recv_obj: BatchEmbeddingOut):
        # If it is embedding model, no detokenization is needed.
        return recv_obj

    def handle_batch_token_id_out(self, recv_obj: BatchTokenIDOut):
        bs = len(recv_obj.rids)

        # Initialize decode status
        read_ids, surr_ids = [], []
        for i in range(bs):
            rid = recv_obj.rids[i]
            vid = recv_obj.vids[i]
            if rid not in self.decode_status or self.decode_status[rid].vid != vid:
                s = DecodeStatus(
                    vid=vid,
                    decoded_text=recv_obj.decoded_texts[i],
                    decode_ids=recv_obj.decode_ids[i],
                    surr_offset=0,
                    read_offset=recv_obj.read_offsets[i],
                )
                self.decode_status[rid] = s
            else:
                s = self.decode_status[rid]
                s.decode_ids = recv_obj.decode_ids[i]

            read_ids.append(
                self.trim_matched_stop(
                    s.decode_ids[s.surr_offset :],
                    recv_obj.finished_reasons[i],
                    recv_obj.no_stop_trim[i],
                )
            )
            surr_ids.append(s.decode_ids[s.surr_offset : s.read_offset])

        # TODO(lmzheng): handle skip_special_tokens/spaces_between_special_tokens per request
        surr_texts = self.tokenizer.batch_decode(
            surr_ids,
            skip_special_tokens=recv_obj.skip_special_tokens[0],
            spaces_between_special_tokens=recv_obj.spaces_between_special_tokens[0],
        )
        read_texts = self.tokenizer.batch_decode(
            read_ids,
            skip_special_tokens=recv_obj.skip_special_tokens[0],
            spaces_between_special_tokens=recv_obj.spaces_between_special_tokens[0],
        )

        # Incremental decoding
        output_strs = []
        finished_reqs = []
        for i in range(bs):
            try:
                s = self.decode_status[recv_obj.rids[i]]
            except KeyError:
                raise RuntimeError(
                    f"Decode status not found for request {recv_obj.rids[i]}. "
                    "It may be due to the request being evicted from the decode status due to memory pressure. "
                    "Please increase the maximum number of requests by setting "
                    "the SGLANG_DETOKENIZER_MAX_STATES environment variable to a bigger value than the default value. "
                    f"The current value is {DETOKENIZER_MAX_STATES}. "
                    "For more details, see: https://github.com/sgl-project/sglang/issues/2812"
                )
            new_text = read_texts[i][len(surr_texts[i]) :]
            if recv_obj.finished_reasons[i] is None:
                # Streaming chunk: update the decode status
                if len(new_text) > 0 and not new_text.endswith("�"):
                    s.decoded_text = s.decoded_text + new_text
                    s.surr_offset = s.read_offset
                    s.read_offset = len(s.decode_ids)
                    new_text = ""
                else:
                    new_text = find_printable_text(new_text)
            else:
                finished_reqs.append(recv_obj.rids[i])

            output_strs.append(
                self.trim_matched_stop(
                    s.decoded_text + new_text,
                    recv_obj.finished_reasons[i],
                    recv_obj.no_stop_trim[i],
                )
            )

        out = BatchStrOut(
            rids=recv_obj.rids,
            finished_reasons=recv_obj.finished_reasons,
            output_strs=output_strs,
            prompt_tokens=recv_obj.prompt_tokens,
            completion_tokens=recv_obj.completion_tokens,
            cached_tokens=recv_obj.cached_tokens,
            spec_verify_ct=recv_obj.spec_verify_ct,
            input_token_logprobs_val=recv_obj.input_token_logprobs_val,
            input_token_logprobs_idx=recv_obj.input_token_logprobs_idx,
            output_token_logprobs_val=recv_obj.output_token_logprobs_val,
            output_token_logprobs_idx=recv_obj.output_token_logprobs_idx,
            input_top_logprobs_val=recv_obj.input_top_logprobs_val,
            input_top_logprobs_idx=recv_obj.input_top_logprobs_idx,
            output_top_logprobs_val=recv_obj.output_top_logprobs_val,
            output_top_logprobs_idx=recv_obj.output_top_logprobs_idx,
            output_hidden_states=recv_obj.output_hidden_states,
        )

        # Remove decode status for completed requests
        for rid in finished_reqs:
            self.decode_status.pop(rid)

        return out


class LimitedCapacityDict(OrderedDict):
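The refactor above replaces the isinstance branching inside event_loop with a TypeBasedDispatcher: each incoming message is routed to a handler based on its type, and the handler's return value is sent back to the tokenizer. The dispatcher itself lives in sglang.utils and is not part of this diff; the following is only a minimal sketch, assuming it maps an object's type to the first matching registered handler.

# Hypothetical sketch of a type-based dispatcher; the real
# sglang.utils.TypeBasedDispatcher is not shown in this diff and may differ.
from typing import Any, Callable, List, Tuple, Type


class TypeBasedDispatcher:
    def __init__(self, mapping: List[Tuple[Type, Callable]]):
        # Keep the (type, handler) pairs in registration order.
        self._mapping = mapping

    def __call__(self, obj: Any) -> Any:
        # Route the object to the first handler whose type matches.
        for ty, fn in self._mapping:
            if isinstance(obj, ty):
                return fn(obj)
        raise ValueError(f"Invalid object: {obj}")

With this pattern, event_loop stays a thin receive-dispatch-send loop, and supporting a new message type only requires registering another (type, handler) pair.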
2 changes: 2 additions & 0 deletions python/sglang/srt/models/llama.py
@@ -458,6 +458,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                continue
            if name.startswith("model.vision_tower") and name not in params_dict:
                continue
            if self.config.tie_word_embeddings and "lm_head.weight" in name:
                continue
            # Handle FP8 kv-scale remapping
            if "scale" in name:
                name = maybe_remap_kv_scale_name(name, params_dict)
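This and the following model files all gain the same guard: when config.tie_word_embeddings is set, the output head reuses the input embedding matrix, so an lm_head.weight entry in the checkpoint has no standalone parameter to load and is skipped. A toy illustration of why the tied parameter disappears from named_parameters() (a hedged sketch, not sglang's actual parallel LM head wiring):

# Hypothetical toy module illustrating weight tying; sglang's real models
# tie their LM head differently, but the loading consequence is the same.
import torch
from torch import nn


class TinyTiedLM(nn.Module):
    def __init__(self, vocab_size: int = 128, hidden: int = 16):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab_size, hidden)
        self.lm_head = nn.Linear(hidden, vocab_size, bias=False)
        # Tie the head to the embedding: both names now refer to one tensor.
        self.lm_head.weight = self.embed_tokens.weight


params = dict(TinyTiedLM().named_parameters())
print("embed_tokens.weight" in params)  # True: the shared tensor is reported once.
print("lm_head.weight" in params)       # False: no standalone lm_head parameter to fill.

Because params_dict has no "lm_head.weight" key in the tied case, loading that checkpoint entry would fail, hence the skip.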
2 changes: 2 additions & 0 deletions python/sglang/srt/models/minicpm.py
@@ -339,6 +339,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                # Models trained using ColossalAI may include these tensors in
                # the checkpoint. Skip them.
                continue
            if self.config.tie_word_embeddings and "lm_head.weight" in name:
                continue

            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
2 changes: 2 additions & 0 deletions python/sglang/srt/models/minicpm3.py
@@ -603,6 +603,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                # Models trained using ColossalAI may include these tensors in
                # the checkpoint. Skip them.
                continue
            if self.config.tie_word_embeddings and "lm_head.weight" in name:
                continue

            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
2 changes: 2 additions & 0 deletions python/sglang/srt/models/olmo.py
@@ -325,6 +325,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                # Models trained using ColossalAI may include these tensors in
                # the checkpoint. Skip them.
                continue
            if self.config.tie_word_embeddings and "lm_head.weight" in name:
                continue
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
2 changes: 2 additions & 0 deletions python/sglang/srt/models/phi3_small.py
@@ -433,6 +433,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                continue
            if name.endswith(".bias") and name not in params_dict:
                continue
            if self.config.tie_word_embeddings and "lm_head.weight" in name:
                continue

            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader", default_weight_loader)
2 changes: 2 additions & 0 deletions python/sglang/srt/models/qwen2.py
@@ -377,6 +377,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                # Models trained using ColossalAI may include these tensors in
                # the checkpoint. Skip them.
                continue
            if self.config.tie_word_embeddings and "lm_head.weight" in name:
                continue
            if name.startswith("model.vision_tower") and name not in params_dict:
                continue

2 changes: 2 additions & 0 deletions python/sglang/srt/models/qwen2_vl.py
@@ -586,6 +586,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            if self.config.tie_word_embeddings and "lm_head.weight" in name:
                continue

            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
2 changes: 2 additions & 0 deletions python/sglang/srt/models/torch_native_llama.py
@@ -486,6 +486,8 @@ def load_weights_to_module(
                continue
            if name.startswith("model.vision_tower") and name not in params_dict:
                continue
            if self.config.tie_word_embeddings and "lm_head.weight" in name:
                continue

            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
