feature (vllm): add proper traceability for the vllm sampled tokens
AdamBelfki3 committed Dec 17, 2024
1 parent 4f67ca4 commit 9c8d7eb
Showing 3 changed files with 64 additions and 13 deletions.
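
The change adds a `samples` hook point so the token ids chosen by vLLM's sampler become traceable (and overridable) from an nnsight trace. As a quick orientation before the diff, here is a minimal usage sketch modeled on the tests added below; the `VLLM("gpt2")` construction, the prompt, and passing the sampling parameters directly to `trace` are illustrative assumptions rather than part of this commit.

    import nnsight
    from nnsight.modeling.vllm import VLLM  # import path assumed from this repo's layout

    # Hypothetical setup mirroring the `vllm_gpt2` fixture used in the tests below.
    model = VLLM("gpt2")

    with model.trace("Madison Square Garden is located in the city of",
                     temperature=0.0, top_p=1.0, max_tokens=3):
        sampled = nnsight.list().save()
        for _ in range(3):
            sampled.append(model.samples.output)  # token ids chosen by the sampler
            model.samples.next()                  # move on to the next generated token

    print(model.tokenizer.batch_decode(sampled))  # e.g. [" New", " York", " City"]
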
14 changes: 13 additions & 1 deletion src/nnsight/modeling/vllm/model_runners/GPUModelRunner.py
@@ -311,7 +311,19 @@ def inner():
logits=logits,
sampling_metadata=model_input.sampling_metadata,
)
output.sampled_token_ids = self.model.tokens(output.sampled_token_ids)

og_sample_tokens = torch.tensor([token.samples[0].output_token for token in output.outputs])

with Patcher(patches):
sample_tokens = self.model.samples(og_sample_tokens)

# inject any changes to the sampled tokens
for idx, seq_out in enumerate(output.outputs):
sample = seq_out.samples[0]
sample.output_token = sample_tokens[idx].item()
logprob = sample.logprobs.pop(og_sample_tokens[idx].item())
sample.logprobs[sample_tokens[idx].item()] = logprob

if (
self.observability_config is not None
and self.observability_config.collect_model_forward_time
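
The injection loop above writes the (possibly user-modified) token id back into each sequence's first sample and re-keys its logprob entry so the recorded logprob follows the token that was actually emitted. A standalone sketch of that re-keying, with invented token ids and a plain float standing in for vLLM's Logprob object:

    # Illustrative only: ids and values are made up, and vLLM stores Logprob
    # objects rather than raw floats.
    old_token_id, new_token_id = 262, 1748
    logprobs = {old_token_id: -0.31}              # per-sample logprobs keyed by token id
    logprobs[new_token_id] = logprobs.pop(old_token_id)
    assert logprobs == {new_token_id: -0.31}      # the swapped-in token now owns the entry
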
4 changes: 2 additions & 2 deletions src/nnsight/modeling/vllm/vllm.py
@@ -39,7 +39,7 @@ class VLLM(RemoteableMixin):
- vllm_entrypoint (vllm.LLM): vLLM language model.
- tokenizer (vllm.transformers_utils.tokenizer.AnyTokenizer): tokenizer.
- logits (nnsight.WrapperModule): logits.
- tokens (nnsight.WrapperModule): tokens.
- samples (nnsight.WrapperModule): sampled tokens.
.. code-block:: python
from nnsight.models.VLLM import VLLM
@@ -67,7 +67,7 @@ def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)

self.logits: WrapperModule = WrapperModule()
self.tokens: WrapperModule = WrapperModule()
self.samples: WrapperModule = WrapperModule()

def _load_meta(self, repo_id: str, **kwargs) -> "Module":

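
`samples`, like `logits`, is an nnsight WrapperModule; to my understanding that is essentially a pass-through module whose only job is to give a value a named spot in the module tree so traces can read or overwrite it. A conceptual stand-in, not the real nnsight class:

    import torch

    # Assumption: a WrapperModule behaves like an identity module that simply
    # exposes a named hook point for the values passed through it.
    class IdentityHook(torch.nn.Module):
        def forward(self, x):
            return x

    samples = IdentityHook()
    token_ids = torch.tensor([1748, 319, 262])
    assert torch.equal(samples(token_ids), token_ids)  # unchanged unless a trace intervenes
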
59 changes: 49 additions & 10 deletions tests/test_vllm.py
@@ -76,6 +76,23 @@ def test_multi_token_generation(vllm_gpt2, MSG_prompt: str):
assert vllm_gpt2.tokenizer.batch_decode([logit.argmax(dim=-1) for logit in logits]) == [" New", " York", " City"]


def test_sampling(vllm_gpt2, MSG_prompt: str):
with vllm_gpt2.trace(max_tokens=3) as tracer:
with tracer.invoke(MSG_prompt, temperature=0.0, top_p=1.0, max_tokens=3):
samples_1 = nnsight.list().save()
for ii in range(3):
samples_1.append(vllm_gpt2.samples.output)
vllm_gpt2.samples.next()
with tracer.invoke(MSG_prompt, temperature=0.8, top_p=0.95):
samples_2 = nnsight.list().save()
for ii in range(3):
samples_2.append(vllm_gpt2.samples.output)
vllm_gpt2.samples.next()

assert vllm_gpt2.tokenizer.batch_decode(samples_1) == [" New", " York", " City"]
assert vllm_gpt2.tokenizer.batch_decode(samples_2) == [" Richmond", " on", " the"]


""" def test_max_token_generation(vllm_gpt2, ET_prompt: str):
with vllm_gpt2.trace(ET_prompt, max_tokens=10):
logits = nnsight.list().save()
@@ -138,32 +155,54 @@ def test_batched_intervention(vllm_gpt2, ET_prompt: str,):


def test_batched_multi_token_generation(vllm_gpt2, ET_prompt: str, MSG_prompt: str):
max_token_1: int = 3
max_token_2: int = 5

num_prompts_1: int = 2
num_prompts_2: int = 1

with vllm_gpt2.trace() as tracer:
with tracer.invoke([MSG_prompt, ET_prompt], max_tokens=3):
with tracer.invoke([MSG_prompt, ET_prompt], max_tokens=max_token_1):
MSG_ET_hs = nnsight.list().save()
MSG_ET_logits = nnsight.list().save()
for ii in range(3):
MSG_ET_samples = nnsight.list().save()
for ii in range(max_token_1):
MSG_ET_hs.append(vllm_gpt2.transformer.h[5].output)
vllm_gpt2.transformer.h[5].next()
MSG_ET_logits.append(vllm_gpt2.logits.output)
vllm_gpt2.logits.next()
with tracer.invoke(MSG_prompt, max_tokens=5):
MSG_ET_samples.append(vllm_gpt2.samples.output)
vllm_gpt2.samples.next()
with tracer.invoke(MSG_prompt, max_tokens=max_token_2):
MSG_hs = nnsight.list().save()
MSG_logits = nnsight.list().save()
for ii in range(5):
MSG_samples = nnsight.list().save()
for ii in range(max_token_2):
MSG_hs.append(vllm_gpt2.transformer.h[5].output)
vllm_gpt2.transformer.h[5].next()
MSG_logits.append(vllm_gpt2.logits.output)
vllm_gpt2.logits.next()
MSG_samples.append(vllm_gpt2.samples.output)
vllm_gpt2.samples.next()

assert len(MSG_ET_hs) == 3
assert len(MSG_ET_hs) == max_token_1
assert all(hs.shape[0] for hs in MSG_ET_hs[1:])
assert len(MSG_ET_logits) == 3
assert all(logit.shape[0] == 2 for logit in MSG_ET_logits)
assert len(MSG_hs) == 5

assert len(MSG_ET_logits) == max_token_1
assert all(logit.shape[0] == num_prompts_1 for logit in MSG_ET_logits)

assert len(MSG_ET_samples) == max_token_1
assert all(sample.shape[0] == num_prompts_1 for sample in MSG_ET_samples)


assert len(MSG_hs) == max_token_2
assert all(hs.shape[0] for hs in MSG_hs[1:])
assert len(MSG_logits) == 5
assert all(logit.shape[0] == 1 for logit in MSG_logits)

assert len(MSG_logits) == max_token_2
assert all(logit.shape[0] == num_prompts_2 for logit in MSG_logits)

assert len(MSG_samples) == max_token_2
assert all(sample.shape[0] == num_prompts_2 for sample in MSG_samples)


""" def test_batched_multi_token_generation_with_iter(vllm_gpt2, ET_prompt: str, MSG_prompt: str):
