DRAFT: start of implementation of delayed sampling for prompts
First draft of where changes are needed; not functional yet.
Comments in execute_model mark ideas about what should be changed and where.
kamil-kaczor committed Dec 20, 2024
1 parent da61ecf commit 11a87dc
Showing 7 changed files with 221 additions and 40 deletions.
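Before the per-file diffs, here is a minimal, self-contained sketch of the intended delayed-sampling flow as this draft describes it: the prefill step returns a placeholder token id of -1 and keeps the logits around, and the real token is only materialized on the next model execution. The class and constant names below are illustrative only, not vLLM APIs.

import torch
from typing import Optional

PLACEHOLDER_TOKEN_ID = -1  # discarded by the output processor (see single_step.py below)

class DelayedSamplerSketch:
    """Toy illustration: keep logits for one extra step and only
    materialize the sampled token id at the next call."""

    def __init__(self) -> None:
        # Mirrors the prev_logits field added to SequenceData below.
        self.prev_logits: Optional[torch.Tensor] = None

    def step(self, logits: torch.Tensor) -> int:
        if self.prev_logits is None:
            # First (prefill) execution: no valid token to return yet.
            token_id = PLACEHOLDER_TOKEN_ID
        else:
            # Sample the token that was deferred on the previous step.
            token_id = int(torch.argmax(self.prev_logits, dim=-1))
        self.prev_logits = logits  # defer sampling of the current step
        return token_id

sampler = DelayedSamplerSketch()
print(sampler.step(torch.randn(32000)))  # -1, placeholder, discarded
print(sampler.step(torch.randn(32000)))  # first real token id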
9 changes: 9 additions & 0 deletions vllm/config.py
@@ -1129,6 +1129,9 @@ class SchedulerConfig:
# tokens in prefill.
use_padding_aware_scheduling: bool = False

# If True, delayed sampling will be enabled for prompts.
enable_delayed_sampling: bool = False

def __post_init__(self) -> None:
if self.max_num_batched_tokens is None:
if self.enable_chunked_prefill:
@@ -1204,6 +1207,12 @@ def _verify_args(self) -> None:
raise ValueError("Padding-aware scheduling currently "
"does not work with chunked prefill ")

if self.enable_delayed_sampling and self.num_lookahead_slots != 1:
raise ValueError(
"num_lookahead_slots "
f"({self.num_lookahead_slots}) must be 1 for delayed sampling."
)

@property
def is_multi_step(self) -> bool:
return self.num_scheduler_steps > 1
9 changes: 9 additions & 0 deletions vllm/engine/arg_utils.py
@@ -170,6 +170,7 @@ class EngineArgs:

scheduler_delay_factor: float = 0.0
enable_chunked_prefill: Optional[bool] = None
enable_delayed_sampling: bool = False

guided_decoding_backend: str = 'xgrammar'
# Speculative decoding configuration.
@@ -714,6 +715,13 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
const="True",
help='If set, the prefill requests can be chunked based on the '
'max_num_batched_tokens.')
parser.add_argument(
'--enable-delayed-sampling',
action='store_true',
help='If set, sampling will be delayed by one step. The first '
'model execution for a request (prefill) will return an invalid '
'token id that is discarded; actual sampling of valid token ids '
'starts from the second model execution.')

parser.add_argument(
'--speculative-model',
@@ -1178,6 +1186,7 @@ def create_engine_config(self,
enable_chunked_prefill=self.enable_chunked_prefill,
is_multimodal_model=model_config.is_multimodal_model,
preemption_mode=self.preemption_mode,
enable_delayed_sampling=self.enable_delayed_sampling,
num_scheduler_steps=self.num_scheduler_steps,
multi_step_stream_outputs=self.multi_step_stream_outputs,
send_delta_data=(envs.VLLM_USE_RAY_SPMD_WORKER
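For context, a hedged sketch of how the new --enable-delayed-sampling flag added in add_cli_args above might be exercised programmatically once the draft is functional; the model name is a placeholder, and calling create_engine_config() with no arguments is an assumption here.

from vllm.engine.arg_utils import EngineArgs

engine_args = EngineArgs(
    model="facebook/opt-125m",      # placeholder model
    enable_delayed_sampling=True,   # field added in this commit
    num_lookahead_slots=1,          # required by the SchedulerConfig check above
)
# Building the config runs SchedulerConfig._verify_args(), which raises a
# ValueError if num_lookahead_slots is not 1 while delayed sampling is on.
config = engine_args.create_engine_config()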
4 changes: 3 additions & 1 deletion vllm/engine/output_processor/interfaces.py
@@ -36,7 +36,9 @@ def create_output_processor(
This returns a single-step output processor if num_lookahead_slots is
zero (or one with delayed sampling enabled), else returns a
multi-step output processor.
"""
if scheduler_config.num_lookahead_slots == 0:
if (scheduler_config.num_lookahead_slots == 0
or (scheduler_config.num_lookahead_slots == 1
and scheduler_config.enable_delayed_sampling)):
# Importing here to avoid cycle.
from vllm.engine.output_processor.single_step import (
SingleStepOutputProcessor)
4 changes: 3 additions & 1 deletion vllm/engine/output_processor/single_step.py
@@ -116,7 +116,9 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup,

sample = outputs.samples[0]
seq = seq_group.first_seq
if not is_async:
# An output token id of -1 is not valid (e.g. the first token when
# delayed sampling is enabled) and must not be appended.
if not is_async and sample.output_token != -1:
seq.append_token_id(sample.output_token, sample.logprobs)
if sampling_params.detokenize and self.detokenizer:
new_char_count = self.detokenizer.decode_sequence_inplace(
32 changes: 23 additions & 9 deletions vllm/model_executor/layers/sampler.py
@@ -186,6 +186,7 @@ def __init__(self):
# containing the sampled token ids and probabilities. This is used by
# speculative decoding.
self.include_gpu_probs_tensor = False
self.sample_token_positions_only = False
self.should_modify_greedy_probs_inplace = False

def _init_sampling_tensors(
@@ -302,6 +303,7 @@ def forward(
sampling_tensors,
include_gpu_probs_tensor=self.include_gpu_probs_tensor,
modify_greedy_probs=self._should_modify_greedy_probs_inplace,
token_positions_only=self.sample_token_positions_only,
)

if self.include_gpu_probs_tensor:
@@ -588,6 +590,7 @@ def _apply_min_p(
def _greedy_sample(
selected_seq_groups: List[SequenceGroupToSample],
samples: torch.Tensor,
token_positions_only: bool = False,
) -> SampleResultType:
"""Run greedy sampling on a given samples.
@@ -601,7 +604,9 @@ def _greedy_sample(
same as the length of selected_seq_groups. If the corresponding
seq_group has do_sample=False, tuple contains ([], [])
"""
samples_lst = samples.tolist()
print("DEBUG sync in greedy_sample")
if not token_positions_only:
samples_lst = samples.tolist()
sample_idx = 0
results: SampleResultType = []
for seq_group in selected_seq_groups:
@@ -614,7 +619,9 @@ def _greedy_sample(
assert num_parent_seqs == 1, (
"Greedy sampling should have only one seq.")
parent_ids = list(range(num_parent_seqs))
next_token_ids = [samples_lst[sample_idx]]
next_token_ids = [
sample_idx if token_positions_only else samples_lst[sample_idx]
]
results.append((next_token_ids, parent_ids))
sample_idx += num_parent_seqs
return results
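As I read the token_positions_only path above, its point is to skip the samples.tolist() device-to-host sync and return only each sequence's row position in the sampled tensor, resolving the actual ids later. A standalone sketch of that idea (names are illustrative, not vLLM APIs):

import torch

def greedy_sample_sketch(logits: torch.Tensor, token_positions_only: bool = False):
    samples = torch.argmax(logits, dim=-1)  # sampled ids, still on the device
    if token_positions_only:
        # No .tolist(): hand back row positions plus the device tensor,
        # to be turned into real token ids on a later step.
        return list(range(samples.shape[0])), samples
    return samples.tolist(), samples        # eager path: device -> host copy here

positions, deferred = greedy_sample_sketch(torch.randn(3, 32000),
                                            token_positions_only=True)
print(positions)  # [0, 1, 2] -- placeholders; the real ids live in `deferred`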
@@ -637,7 +644,8 @@ def _random_sample(
seq_group has do_sample=False, tuple contains ([], [])
"""
# Find the maximum n value of the prompt phase requests.
random_samples = random_samples.cpu()
print("DEBUG sync in random_sample")
random_samples = random_samples.cpu()  # TODO: use torch.zeros here instead, to avoid copying to the CPU
sample_idx = 0
results: SampleResultType = []
for seq_group in selected_seq_groups:
@@ -653,7 +661,7 @@ def _random_sample(
# Prompt phase.
parent_ids = [0] * sampling_params.n
next_token_ids = random_samples[
sample_idx, :sampling_params.n].tolist()
sample_idx, :sampling_params.n].tolist()  # this is also a host-side operation, is it needed?
else:
# Generation phase.
parent_ids = list(range(num_parent_seqs))
@@ -798,7 +806,8 @@ def _top_k_top_p_multinomial_with_flashinfer(


def get_pythonized_sample_results(
sample_result_args: SampleResultArgsType) -> SampleResultType:
sample_result_args: SampleResultArgsType,
token_positions_only: bool = False) -> SampleResultType:
'''This function consumes GPU-side sampler results and computes
Pythonized CPU-side sampler results (GPU -> CPU sync.)
@@ -836,7 +845,8 @@ def get_pythonized_sample_results(
continue
(seq_group_id, seq_groups) = sample_metadata[sampling_type]
if sampling_type == SamplingType.GREEDY:
sample_results = _greedy_sample(seq_groups, greedy_samples)
sample_results = _greedy_sample(seq_groups, greedy_samples,
token_positions_only)
elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
sample_results = _random_sample(seq_groups,
multinomial_samples[sampling_type])
@@ -858,6 +868,7 @@ def _sample_with_torch(
sampling_tensors: SamplingTensors,
include_gpu_probs_tensor: bool,
modify_greedy_probs: bool,
token_positions_only: bool = False,
) -> SampleReturnType:
'''Torch-oriented _sample() implementation.
@@ -967,13 +978,14 @@ def _sample_with_torch(
greedy_samples=greedy_samples,
beam_search_logprobs=beam_search_logprobs,
sample_results_dict=sample_results_dict)

#import pdb; pdb.set_trace()
if not sampling_metadata.skip_sampler_cpu_output:
# GPU<->CPU sync happens here.
# GPU<->CPU sync happens here,
# unless we're storing only token positions (token_positions_only=True).
# This also converts the sampler output to a Python object.
# Return Pythonized sampler result & sampled token ids
return get_pythonized_sample_results(
maybe_deferred_args), sampled_token_ids_tensor
maybe_deferred_args, token_positions_only), sampled_token_ids_tensor
else:
# Defer sampler result Pythonization; return deferred
# Pythonization args & sampled token ids
@@ -990,6 +1002,7 @@ def _sample(
sampling_tensors: SamplingTensors,
include_gpu_probs_tensor: bool,
modify_greedy_probs: bool,
token_positions_only: bool,
) -> SampleReturnType:
"""
Args:
@@ -1010,6 +1023,7 @@ def _sample(
sampling_tensors,
include_gpu_probs_tensor=include_gpu_probs_tensor,
modify_greedy_probs=modify_greedy_probs,
token_positions_only=token_positions_only,
)


15 changes: 9 additions & 6 deletions vllm/sequence.py
@@ -175,6 +175,10 @@ class SequenceData(msgspec.Struct,
# It is used to compute mrope_position_ids.
_mrope_position_delta: Optional[int] = None

# Used for delayed prompt sampling.
prev_logits = None
prev_logits_idx = None

@staticmethod
def from_prompt_token_counts(
*token_counts: Tuple[int, int]) -> "SequenceData":
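A hedged sketch of how prev_logits and prev_logits_idx could be consumed on the following step; the execute_model side is not part of this commit, so the helper below and its name are assumptions, not vLLM code.

import torch
from typing import Optional

def resolve_delayed_token(prev_logits: Optional[torch.Tensor],
                          prev_logits_idx: Optional[int]) -> Optional[int]:
    """Greedily sample the token that was deferred on the previous step."""
    if prev_logits is None or prev_logits_idx is None:
        return None  # nothing deferred yet (first execution of the sequence)
    row = prev_logits[prev_logits_idx]      # logits row saved for this sequence
    return int(torch.argmax(row, dim=-1))   # the host sync happens one step late

print(resolve_delayed_token(torch.randn(4, 32000), 2))  # some valid token id
print(resolve_delayed_token(None, None))                # None on the first step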
@@ -279,11 +283,11 @@ def mrope_position_delta(self) -> Optional[int]:
def mrope_position_delta(self, new_mrope_position_delta):
self._mrope_position_delta = new_mrope_position_delta

def append_token_id(self, token_id: int, logprob: float) -> None:
def append_token_id(self, token_id: int, logprob: Optional[float]) -> None:
self._output_token_ids.append(token_id)
self._new_appended_tokens.append(token_id)
self._cached_all_token_ids.append(token_id)
self._cumulative_logprob += logprob
self._cumulative_logprob += logprob if logprob is not None else 0.0

def get_len(self) -> int:
return len(self._output_token_ids) + len(self._prompt_token_ids)
Expand Down Expand Up @@ -315,8 +319,6 @@ def get_num_computed_tokens(self) -> int:
def update_num_computed_tokens(self, num_new_computed_tokens: int):
"""Update number of tokens computed so far."""
self._num_computed_tokens += num_new_computed_tokens
assert self._num_computed_tokens <= self.get_len(), (
self._num_computed_tokens, self.get_len())
# If all tokens are computed, it means it is in decoding phase.
if self.get_num_uncomputed_tokens() == 0:
self._stage = SequenceStage.DECODE
@@ -536,9 +538,10 @@ def reset_state_for_recompute(self):

def append_token_id(self, token_id: int, logprobs: Dict[int,
Logprob]) -> None:
assert token_id in logprobs
self.output_logprobs.append(logprobs)
self.data.append_token_id(token_id, logprobs[token_id].logprob)
self.data.append_token_id(
token_id,
logprobs[token_id].logprob if token_id in logprobs else None)

def get_len(self) -> int:
return self.data.get_len()
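Finally, a tiny standalone illustration (assumed names) of why append_token_id now takes an Optional logprob: the -1 placeholder has no entry in the logprobs dict, so its contribution to the cumulative logprob is treated as 0.0.

from typing import Optional

def accumulate_logprob(cumulative: float, logprob: Optional[float]) -> float:
    return cumulative + (logprob if logprob is not None else 0.0)

total = accumulate_logprob(0.0, None)     # placeholder token: total unchanged
total = accumulate_logprob(total, -1.25)  # real token sampled one step later
print(total)                              # -1.25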