VLLM Sampled Tokens #311

Merged on Jan 31, 2025 (4 commits)

31 changes: 24 additions & 7 deletions src/nnsight/modeling/vllm/model_runners/GPUModelRunner.py
@@ -292,12 +292,17 @@ def inner():
hidden_or_intermediate_states, model_input.sampling_metadata
)

with Patcher(
[
Patch(interleaver, [(idx, 1) for idx in range(len(interleaver.batch_groups))], "batch_groups"),
Patch(interleaver, len(interleaver.batch_groups), "batch_size")
]
):
# Patch `batch_size` to be the number of logits,
# since vLLM pads the input to a power-of-two size as an inference optimization.
patches = [Patch(interleaver, logits.shape[0], "batch_size")]

# During the first token-generation iteration, `batch_groups` is adapted to the token positions of the flattened input.
# Since the logit and sample tensors contain different numbers of tokens,
# we patch `batch_groups` to reflect the batches specified by the user's invoker contexts.
if model_input.sampling_metadata.seq_groups[0].is_prompt:
patches.append(Patch(interleaver, model_input.sampling_metadata.nns_batch_groups, "batch_groups"))

with Patcher(patches):
logits = self.model.logits(logits)

if not self.is_driver_worker:
@@ -311,7 +316,19 @@ def inner():
logits=logits,
sampling_metadata=model_input.sampling_metadata,
)
output.sampled_token_ids = self.model.tokens(output.sampled_token_ids)

og_sample_tokens = torch.tensor([token.samples[0].output_token for token in output.outputs])

with Patcher(patches):
sample_tokens = self.model.samples(og_sample_tokens)

# Inject any changes made to the sampled tokens back into the output.
for idx, seq_out in enumerate(output.outputs):
sample = seq_out.samples[0]
sample.output_token = sample_tokens[idx].item()
logprob = sample.logprobs.pop(og_sample_tokens[idx].item())
sample.logprobs[sample_tokens[idx].item()] = logprob

if (
self.observability_config is not None
and self.observability_config.collect_model_forward_time
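For readers unfamiliar with the patching pattern above, here is a minimal, self-contained sketch of a temporary attribute override with Patch/Patcher. The (target, replacement, attribute-name) call signature and the list argument to Patcher follow the usage in this diff; the `nnsight.util` import path and the restore-on-exit behavior are assumptions rather than something this diff demonstrates.

from nnsight.util import Patch, Patcher

class Interleaver:
    """Hypothetical stand-in carrying the attribute to override."""
    batch_size = 4

interleaver = Interleaver()

# While the Patcher context is active, `batch_size` reads as the padded value;
# the original value is assumed to be restored when the context exits.
with Patcher([Patch(interleaver, 8, "batch_size")]):
    assert interleaver.batch_size == 8

assert interleaver.batch_size == 4

This is the same pattern the model runner uses to swap `batch_size` and `batch_groups` on the interleaver only for the duration of the `logits` and `samples` calls.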
14 changes: 13 additions & 1 deletion src/nnsight/modeling/vllm/sampling.py
@@ -1,5 +1,5 @@
import copy
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Dict, List, Optional, Tuple

from nnsight.intervention.graph import InterventionGraph
import torch
@@ -17,6 +17,7 @@
class NNsightSamplingParams(SamplingParams):

intervention_graph: Optional[InterventionGraph] = None
nns_batch_groups: Optional[List[Tuple[int, int]]] = None
invoker_group: Optional[int] = None
is_default_param: bool = True

@@ -43,19 +44,22 @@ def clone(self) -> "SamplingParams":
class NNsightSamplingMetadata(SamplingMetadata):

intervention_graph: Optional[InterventionGraph] = None
nns_batch_groups: Optional[List[Tuple[int, int]]] = None
batch_groups: Optional[List[Tuple[int, int]]] = None

def __init__(
self,
*args,
intervention_graph: InterventionGraph = None,
nns_batch_groups: List[Tuple[int, int]] = None,
batch_groups: Dict[int, Tuple[int, int]] = None,
**kwargs,
):

super().__init__(*args, **kwargs)

self.intervention_graph = intervention_graph
self.nns_batch_groups = nns_batch_groups
self.batch_groups = batch_groups

@staticmethod
@@ -96,6 +100,7 @@ def prepare(
### NNSIGHT ###########################################

intervention_graphs = []
nns_batch_groups = []
batch_groups = []
batch_groups_offset = 0

@@ -106,13 +111,17 @@
seq_group_intervention_graph = (
seq_group.sampling_params.intervention_graph
)

seq_group_nns_batch_groups = seq_group.sampling_params.nns_batch_groups

if isinstance(seq_group_intervention_graph, InterventionGraph):

if seq_group_intervention_graph not in intervention_graphs:

intervention_graphs.append(seq_group_intervention_graph)

nns_batch_groups.append(seq_group_nns_batch_groups)

batch_groups_offset = len(batch_groups)

seq_group_batch_group = (
@@ -136,8 +145,10 @@

if n_graphs == 0:
intervention_graph = None
nns_batch_groups = None
elif n_graphs == 1:
intervention_graph = intervention_graphs[0]
nns_batch_groups = nns_batch_groups[0]

""" else:
intervention_graph = MultiGraph(intervention_graphs.values())
@@ -152,6 +163,7 @@
categorized_sample_indices=categorized_sample_indices,
num_prompts=num_prompts,
intervention_graph=intervention_graph,
nns_batch_groups=nns_batch_groups,
batch_groups=batch_groups,
)

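For orientation, `nns_batch_groups` and `batch_groups` are both typed as List[Tuple[int, int]]. A plausible reading, suggested by the `[(idx, 1) for idx in ...]` construction deleted from GPUModelRunner.py but not stated explicitly in this diff, is that each tuple is a (start, size) pair partitioning the flattened batch into per-invoker groups. A minimal sketch under that assumption, with illustrative names and values:

import torch

# Hypothetical flattened activations for two invokes: the first contributed
# two prompts, the second one prompt.
flattened = torch.arange(3 * 4, dtype=torch.float32).reshape(3, 4)
batch_groups = [(0, 2), (2, 1)]  # assumed (start, size) per invoker group

# Slice each invoker's rows out of the flattened tensor.
per_invoke = [flattened[start:start + size] for start, size in batch_groups]

assert per_invoke[0].shape[0] == 2
assert per_invoke[1].shape[0] == 1

Under this reading, patching `nns_batch_groups` over `batch_groups` during the prompt iteration keeps the slicing aligned with the user's invoker contexts rather than with vLLM's padded token layout.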
5 changes: 3 additions & 2 deletions src/nnsight/modeling/vllm/vllm.py
100644 → 100755
@@ -40,7 +40,7 @@ class VLLM(RemoteableMixin):
- vllm_entrypoint (vllm.LLM): vLLM language model.
- tokenizer (vllm.transformers_utils.tokenizer.AnyTokenizer): tokenizer.
- logits (nnsight.WrapperModule): logits.
- tokens (nnsight.WrapperModule): tokens.
- samples (nnsight.WrapperModule): sampled tokens.

.. code-block:: python
from nnsight.models.VLLM import VLLM
@@ -68,7 +68,7 @@ def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)

self.logits: Envoy = WrapperModule()
self.tokens: Envoy = WrapperModule()
self.samples: Envoy = WrapperModule()

def _load_meta(self, repo_id: str, **kwargs) -> "Module":

@@ -242,6 +242,7 @@ def interleave(
for param in args[1]:

param.intervention_graph = interleaver.graph
param.nns_batch_groups = interleaver.batch_groups

if fn is None:
fn = self._execute
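The rename from `tokens` to `samples`, together with the injection loop in GPUModelRunner.py, suggests a trace can both read and overwrite the sampled tokens. The tests below only read them; the following overwrite sketch is an assumption about the intended usage: it presumes a WrapperModule's output can be set with `.output = ...` like other envoys, that `vllm_gpt2` is the same pytest fixture used in the tests, and that " London" encodes to a single GPT-2 token.

import torch

forced_id = vllm_gpt2.tokenizer(" London")["input_ids"][0]

with vllm_gpt2.trace("The Eiffel Tower is located in the city of", temperature=0.0, top_p=1.0):
    # Replace the token sampled at this generation step.
    vllm_gpt2.samples.output = torch.tensor([forced_id])
    overridden = vllm_gpt2.samples.output.save()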
67 changes: 59 additions & 8 deletions tests/test_vllm.py
@@ -76,6 +76,23 @@ def test_multi_token_generation(vllm_gpt2, MSG_prompt: str):
assert vllm_gpt2.tokenizer.batch_decode([logit.argmax(dim=-1) for logit in logits]) == [" New", " York", " City"]


def test_sampling(vllm_gpt2, MSG_prompt: str):
with vllm_gpt2.trace(max_tokens=3) as tracer:
with tracer.invoke(MSG_prompt, temperature=0.0, top_p=1.0, max_tokens=3):
samples_1 = nnsight.list().save()
for ii in range(3):
samples_1.append(vllm_gpt2.samples.output)
vllm_gpt2.samples.next()
with tracer.invoke(MSG_prompt, temperature=0.8, top_p=0.95):
samples_2 = nnsight.list().save()
for ii in range(3):
samples_2.append(vllm_gpt2.samples.output)
vllm_gpt2.samples.next()

assert vllm_gpt2.tokenizer.batch_decode(samples_1) == [" New", " York", " City"]
assert vllm_gpt2.tokenizer.batch_decode(samples_2) == [" Richmond", " on", " the"]


""" def test_max_token_generation(vllm_gpt2, ET_prompt: str):
with vllm_gpt2.trace(ET_prompt, max_tokens=10):
logits = nnsight.list().save()
@@ -138,20 +155,54 @@ def test_batched_intervention(vllm_gpt2, ET_prompt: str,):


def test_batched_multi_token_generation(vllm_gpt2, ET_prompt: str, MSG_prompt: str):
max_token_1: int = 3
max_token_2: int = 5

num_prompts_1: int = 2
num_prompts_2: int = 1

with vllm_gpt2.trace() as tracer:
with tracer.invoke(ET_prompt, max_tokens=3):
ET_logits = nnsight.list().save()
for ii in range(3):
ET_logits.append(vllm_gpt2.logits.output)
with tracer.invoke([MSG_prompt, ET_prompt], max_tokens=max_token_1):
MSG_ET_hs = nnsight.list().save()
MSG_ET_logits = nnsight.list().save()
MSG_ET_samples = nnsight.list().save()
for ii in range(max_token_1):
MSG_ET_hs.append(vllm_gpt2.transformer.h[5].output)
vllm_gpt2.transformer.h[5].next()
MSG_ET_logits.append(vllm_gpt2.logits.output)
vllm_gpt2.logits.next()
with tracer.invoke(MSG_prompt, max_tokens=5):
MSG_ET_samples.append(vllm_gpt2.samples.output)
vllm_gpt2.samples.next()
with tracer.invoke(MSG_prompt, max_tokens=max_token_2):
MSG_hs = nnsight.list().save()
MSG_logits = nnsight.list().save()
for ii in range(5):
MSG_samples = nnsight.list().save()
for ii in range(max_token_2):
MSG_hs.append(vllm_gpt2.transformer.h[5].output)
vllm_gpt2.transformer.h[5].next()
MSG_logits.append(vllm_gpt2.logits.output)
vllm_gpt2.logits.next()
MSG_samples.append(vllm_gpt2.samples.output)
vllm_gpt2.samples.next()

assert len(MSG_ET_hs) == max_token_1
assert all(hs.shape[0] == num_prompts_1 for hs in MSG_ET_hs[1:])

assert len(MSG_ET_logits) == max_token_1
assert all(logit.shape[0] == num_prompts_1 for logit in MSG_ET_logits)

assert len(MSG_ET_samples) == max_token_1
assert all(sample.shape[0] == num_prompts_1 for sample in MSG_ET_samples)


assert len(MSG_hs) == max_token_2
assert all(hs.shape[0] == num_prompts_2 for hs in MSG_hs[1:])

assert len(MSG_logits) == max_token_2
assert all(logit.shape[0] == num_prompts_2 for logit in MSG_logits)

assert len(ET_logits) == 3
assert len(MSG_logits) == 5
assert len(MSG_samples) == max_token_2
assert all(sample.shape[0] == num_prompts_2 for sample in MSG_samples)


""" def test_batched_multi_token_generation_with_iter(vllm_gpt2, ET_prompt: str, MSG_prompt: str):