[Feat] Return hidden states (experimental) (#3364)
Co-authored-by: Chayenne <[email protected]>
Jackmin801 and zhaochenyang20 authored Feb 10, 2025
1 parent 2f47d71 commit 5f0e7de
Showing 12 changed files with 204 additions and 5 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/pr-test.yml
@@ -103,7 +103,7 @@ jobs:
bash scripts/ci_install_dependency.sh
- name: Run test
timeout-minutes: 25
timeout-minutes: 30
run: |
RANGE=${{ matrix.range }}
range_begin=${RANGE%-*}
53 changes: 53 additions & 0 deletions docs/backend/offline_engine_api.ipynb
@@ -179,6 +179,59 @@
"asyncio.run(main())"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"llm.shutdown()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Return Hidden States"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import sglang as sgl\n",
"\n",
"llm = sgl.Engine(\n",
" model_path=\"meta-llama/Meta-Llama-3.1-8B-Instruct\", return_hidden_states=True\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"prompts = [\n",
" \"Hello, my name is\",\n",
" \"The president of the United States is\",\n",
" \"The capital of France is\",\n",
" \"The future of AI is\",\n",
"]\n",
"\n",
"sampling_params = {\"temperature\": 0.8, \"top_p\": 0.95, \"max_new_tokens\": 10}\n",
"\n",
"outputs = llm.generate(prompts, sampling_params=sampling_params)\n",
"for prompt, output in zip(prompts, outputs):\n",
" print(\"===============================\")\n",
" print(\n",
" f\"Prompt: {prompt}\\nGenerated text: {output['text']}\\nPrompt_Tokens: {output['meta_info']['prompt_tokens']}\\tCompletion_tokens: {output['meta_info']['completion_tokens']}\\nHidden states: {[i.shape for i in output['meta_info']['hidden_states']]}\"\n",
" )\n",
" print()"
]
},
{
"cell_type": "code",
"execution_count": null,
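The first entry of output["meta_info"]["hidden_states"] covers the entire prefill (one row per prompt token), and each later entry is the last-layer vector for a single decoded token. A minimal sketch of stacking them into one [seq_len, hidden_size] tensor, mirroring what the new test below does:

import torch

# hidden_states: a 2-D [prompt_len, hidden_size] tensor for the prefill,
# followed by one 1-D [hidden_size] vector per decoded token.
steps = output["meta_info"]["hidden_states"]
stacked = torch.cat([h.unsqueeze(0) if h.dim() == 1 else h for h in steps])
# stacked.shape == [prompt_tokens + completion_tokens, hidden_size]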
1 change: 1 addition & 0 deletions python/sglang/srt/managers/detokenizer_manager.py
@@ -210,6 +210,7 @@ def event_loop(self):
input_top_logprobs_idx=recv_obj.input_top_logprobs_idx,
output_top_logprobs_val=recv_obj.output_top_logprobs_val,
output_top_logprobs_idx=recv_obj.output_top_logprobs_idx,
output_hidden_states=recv_obj.output_hidden_states,
)
)

4 changes: 4 additions & 0 deletions python/sglang/srt/managers/io_struct.py
@@ -371,6 +371,8 @@ class BatchTokenIDOut:
output_top_logprobs_val: List[List]
output_top_logprobs_idx: List[List]

output_hidden_states: List[List[float]]


@dataclass
class BatchStrOut:
@@ -397,6 +399,8 @@ class BatchStrOut:
output_top_logprobs_val: List[List]
output_top_logprobs_idx: List[List]

output_hidden_states: List[List[float]]


@dataclass
class BatchEmbeddingOut:
18 changes: 15 additions & 3 deletions python/sglang/srt/managers/schedule_batch.py
@@ -315,6 +315,7 @@ def __init__(
self.output_token_logprobs_val = self.output_token_logprobs_idx = (
self.output_top_logprobs_val
) = self.output_top_logprobs_idx = None
self.hidden_states = []

# Logprobs (internal values)
# The tokens are prefilled but need to be considered as decode tokens
@@ -604,6 +605,9 @@ class ScheduleBatch:
# Enable custom logit processor
enable_custom_logit_processor: bool = False

# Return hidden states
return_hidden_states: bool = False

@classmethod
def init_new(
cls,
@@ -615,6 +619,7 @@ def init_new(
enable_overlap: bool,
spec_algorithm: SpeculativeAlgorithm,
enable_custom_logit_processor: bool,
return_hidden_states: bool = False,
):
return cls(
reqs=reqs,
@@ -629,6 +634,7 @@ def init_new(
device=req_to_token_pool.device,
spec_algorithm=spec_algorithm,
enable_custom_logit_processor=enable_custom_logit_processor,
return_hidden_states=return_hidden_states,
)

def batch_size(self):
@@ -1196,9 +1202,15 @@ def get_model_worker_batch(self):
spec_algorithm=self.spec_algorithm,
spec_info=self.spec_info,
capture_hidden_mode=(
getattr(self.spec_info, "capture_hidden_mode", CaptureHiddenMode.NULL)
if self.spec_info
else CaptureHiddenMode.NULL
CaptureHiddenMode.FULL
if self.return_hidden_states
else (
getattr(
self.spec_info, "capture_hidden_mode", CaptureHiddenMode.NULL
)
if self.spec_info
else CaptureHiddenMode.NULL
)
),
)

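The nested conditional above encodes a simple precedence for the capture mode: the new server flag forces full capture, otherwise speculative decoding picks its own mode, otherwise nothing is captured. The same logic as a standalone helper, sketched with an illustrative name (CaptureHiddenMode is sglang's existing enum):

def resolve_capture_mode(return_hidden_states, spec_info):
    # The server-level flag wins: capture hidden states for every position.
    if return_hidden_states:
        return CaptureHiddenMode.FULL
    # Otherwise defer to speculative decoding, which may request its own mode.
    if spec_info is not None:
        return getattr(spec_info, "capture_hidden_mode", CaptureHiddenMode.NULL)
    return CaptureHiddenMode.NULL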
29 changes: 29 additions & 0 deletions python/sglang/srt/managers/scheduler.py
@@ -997,6 +997,7 @@ def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
self.enable_overlap,
self.spec_algorithm,
self.server_args.enable_custom_logit_processor,
self.server_args.return_hidden_states,
)
new_batch.prepare_for_extend()

@@ -1156,6 +1157,8 @@ def process_batch_result_prefill(
logits_output.input_token_logprobs.tolist()
)

hidden_state_offset = 0

# Check finish conditions
logprob_pt = 0
for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
@@ -1182,6 +1185,21 @@
i, req, logprob_pt, next_token_ids, logits_output
)

if (
self.server_args.return_hidden_states
and logits_output.hidden_states is not None
):
req.hidden_states.append(
logits_output.hidden_states[
hidden_state_offset : (
hidden_state_offset := hidden_state_offset
+ len(req.origin_input_ids)
)
]
.cpu()
.clone()
)

if req.grammar is not None:
req.grammar.accept_token(next_token_id)
req.grammar.finished = req.finished()
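In process_batch_result_prefill, logits_output.hidden_states arrives as one tensor packed over all prompt tokens in the batch, so the walrus expression above both slices out a request's rows and advances the running offset. The same bookkeeping unrolled, as a sketch with illustrative names:

offset = 0
for req in batch.reqs:
    n = len(req.origin_input_ids)                  # rows owned by this request
    chunk = packed_hidden_states[offset : offset + n]
    req.hidden_states.append(chunk.cpu().clone())  # detach from the GPU buffer
    offset += n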
@@ -1275,6 +1293,12 @@ def process_batch_result_decode(
logits_output.next_token_top_logprobs_idx[i]
)

if (
self.server_args.return_hidden_states
and logits_output.hidden_states is not None
):
req.hidden_states.append(logits_output.hidden_states[i].cpu().clone())

if req.grammar is not None:
req.grammar.accept_token(next_token_id)
req.grammar.finished = req.finished()
@@ -1398,6 +1422,7 @@ def stream_output(
completion_tokens = []
cached_tokens = []
spec_verify_ct = []
hidden_states = []

if return_logprob:
input_token_logprobs_val = []
@@ -1464,6 +1489,8 @@
output_top_logprobs_val.append(req.output_top_logprobs_val)
output_top_logprobs_idx.append(req.output_top_logprobs_idx)

hidden_states.append(req.hidden_states)

# Send to detokenizer
if rids:
self.send_to_detokenizer.send_pyobj(
@@ -1490,6 +1517,7 @@
input_top_logprobs_idx,
output_top_logprobs_val,
output_top_logprobs_idx,
hidden_states,
)
)
else: # embedding or reward model
@@ -1553,6 +1581,7 @@ def get_idle_batch(self):
self.enable_overlap,
self.spec_algorithm,
self.server_args.enable_custom_logit_processor,
self.server_args.return_hidden_states,
)
idle_batch.prepare_for_idle()
return idle_batch
6 changes: 6 additions & 0 deletions python/sglang/srt/managers/tokenizer_manager.py
@@ -796,6 +796,12 @@ def _handle_batch_output(
}
)

if (
hasattr(recv_obj, "output_hidden_states")
and len(recv_obj.output_hidden_states[i]) > 0
):
meta_info["hidden_states"] = recv_obj.output_hidden_states[i]

if isinstance(recv_obj, BatchStrOut):
out_dict = {
"text": recv_obj.output_strs[i],
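On the client side the values surface under meta_info, and only when the server was launched with the flag, so defensive access is sensible. A minimal sketch:

hidden_states = output["meta_info"].get("hidden_states")
if hidden_states is not None:
    # one entry per scheduler step: the prefill first, then one per decoded token
    print([h.shape for h in hidden_states])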
4 changes: 4 additions & 0 deletions python/sglang/srt/managers/tp_worker_overlap_thread.py
@@ -156,6 +156,10 @@ def forward_thread_func_(self):
logits_output.input_token_logprobs = (
logits_output.input_token_logprobs.to("cpu", non_blocking=True)
)
if logits_output.hidden_states is not None:
logits_output.hidden_states = logits_output.hidden_states.to(
"cpu", non_blocking=True
)
next_token_ids = next_token_ids.to("cpu", non_blocking=True)
copy_done.record()

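The overlap worker queues the device-to-host copy with non_blocking=True and relies on the copy_done event recorded just after it; consumers wait on that event before reading the CPU tensors. The general pattern in isolation (note such copies are only truly asynchronous into pinned host memory):

import torch

gpu_tensor = torch.randn(4, 8, device="cuda")
copy_done = torch.cuda.Event()

cpu_tensor = gpu_tensor.to("cpu", non_blocking=True)  # queue the D2H copy
copy_done.record()                                    # mark its completion point

copy_done.synchronize()  # wait before touching cpu_tensor on the host
print(cpu_tensor.sum())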
8 changes: 7 additions & 1 deletion python/sglang/srt/model_executor/cuda_graph_runner.py
@@ -349,7 +349,13 @@ def capture_one_batch_size(self, bs: int, forward: Callable):
spec_algorithm=self.model_runner.spec_algorithm,
spec_info=spec_info,
capture_hidden_mode=(
spec_info.capture_hidden_mode if spec_info else CaptureHiddenMode.NULL
CaptureHiddenMode.FULL
if self.model_runner.server_args.return_hidden_states
else (
spec_info.capture_hidden_mode
if spec_info
else CaptureHiddenMode.NULL
)
),
)

6 changes: 6 additions & 0 deletions python/sglang/srt/server_args.py
@@ -160,6 +160,7 @@ class ServerArgs:
delete_ckpt_after_loading: bool = False
enable_memory_saver: bool = False
allow_auto_truncate: bool = False
return_hidden_states: bool = False

# Custom logit processor
enable_custom_logit_processor: bool = False
@@ -896,6 +897,11 @@ def add_cli_args(parser: argparse.ArgumentParser):
action="store_true",
help="Enable users to pass custom logit processors to the server (disabled by default for security)",
)
parser.add_argument(
"--return-hidden-states",
action="store_true",
help="Return hidden states in the response.",
)
# Function Calling
parser.add_argument(
"--tool-call-parser",
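The flag also applies to a standalone server (python -m sglang.launch_server --model-path ... --return-hidden-states). A hedged sketch of reading the field back over SGLang's native /generate endpoint; the port, payload shape, and how the tensors are serialized over HTTP are assumptions not shown in this diff:

import requests

# assumes a server launched with --return-hidden-states on the default port
resp = requests.post(
    "http://localhost:30000/generate",
    json={
        "text": "The capital of France is",
        "sampling_params": {"temperature": 0, "max_new_tokens": 8},
    },
)
out = resp.json()
print(out["text"])
print(out["meta_info"].get("hidden_states"))  # populated only with the flag on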
1 change: 1 addition & 0 deletions test/srt/run_suite.py
@@ -46,6 +46,7 @@
"test_torchao.py",
"test_triton_attention_kernels.py",
"test_triton_attention_backend.py",
"test_hidden_states.py",
"test_update_weights_from_disk.py",
"test_update_weights_from_tensor.py",
"test_vision_chunked_prefill.py",
77 changes: 77 additions & 0 deletions test/srt/test_hidden_states.py
@@ -0,0 +1,77 @@
import unittest

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

import sglang as sgl
from sglang.test.test_utils import is_in_ci


class TestHiddenState(unittest.TestCase):
def test_return_hidden_states(self):
prompts = ["Today is", "Today is a sunny day and I like"]
model_path = "meta-llama/Meta-Llama-3.1-8B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_path)
input_ids = tokenizer(prompts).input_ids

sampling_params = {"temperature": 0, "max_new_tokens": 8}

engine = sgl.Engine(
model_path=model_path,
random_seed=42,
return_hidden_states=True,
skip_tokenizer_init=True,
)
outputs = engine.generate(input_ids=input_ids, sampling_params=sampling_params)
engine.shutdown()

for output in outputs:
self.assertEqual(len(output["meta_info"]["hidden_states"]), 8)
for hidden_state in output["meta_info"]["hidden_states"]:
self.assertIsInstance(hidden_state, torch.Tensor)
# Checks that splicing of the batch was done correctly
self.assertGreater(
outputs[1]["meta_info"]["hidden_states"][0].shape[0],
outputs[0]["meta_info"]["hidden_states"][0].shape[0],
)

model = AutoModelForCausalLM.from_pretrained(
model_path, torch_dtype=torch.bfloat16, device_map="cuda"
)

for input_id, output in zip(input_ids, outputs):
with torch.inference_mode():
hf_out = model(
torch.tensor(
[input_id + output["token_ids"][:-1]], device=model.device
),
output_hidden_states=True,
)
print("=== HF Hiddens ===")
print(hf_out["hidden_states"][-1][0])
sg_hidden_states = torch.cat(
[
i.unsqueeze(0) if len(i.shape) == 1 else i
for i in output["meta_info"]["hidden_states"]
]
).to("cuda")
print("=== SRT Hiddens ===")
print(sg_hidden_states)

print(
f"Max diff: {torch.max(torch.abs(hf_out['hidden_states'][-1][0] - sg_hidden_states))}"
)

atol = 0.8 if is_in_ci() else 0.4
self.assertTrue(
torch.allclose(
hf_out["hidden_states"][-1][0],
sg_hidden_states,
atol=atol,
rtol=0,
)
)


if __name__ == "__main__":
unittest.main()
