Draft: Delayed prompts #659

Draft · wants to merge 2 commits into base: habana_main

9 changes: 9 additions & 0 deletions vllm/config.py
@@ -1129,6 +1129,9 @@ class SchedulerConfig:
# tokens in prefill.
use_padding_aware_scheduling: bool = False

# If True, delayed sampling will be enabled for prompts.
enable_delayed_sampling: bool = False

def __post_init__(self) -> None:
if self.max_num_batched_tokens is None:
if self.enable_chunked_prefill:
@@ -1204,6 +1207,12 @@ def _verify_args(self) -> None:
raise ValueError("Padding-aware scheduling currently "
"does not work with chunked prefill ")

if self.enable_delayed_sampling and self.num_lookahead_slots != 1:
raise ValueError(
"num_lookahead_slots "
f"({self.num_lookahead_slots}) must be 1 for delayed sampling."
)

@property
def is_multi_step(self) -> bool:
return self.num_scheduler_steps > 1
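Note on the config change: the new _verify_args() rule ties delayed sampling to exactly one lookahead slot. A minimal standalone sketch of that constraint (ToySchedulerConfig is illustrative and not part of the PR):

from dataclasses import dataclass


@dataclass
class ToySchedulerConfig:
    # Toy stand-in for SchedulerConfig, reduced to the two fields
    # involved in the new check.
    num_lookahead_slots: int = 0
    enable_delayed_sampling: bool = False

    def __post_init__(self) -> None:
        # Mirrors the added _verify_args() rule: delayed sampling
        # requires exactly one lookahead slot.
        if self.enable_delayed_sampling and self.num_lookahead_slots != 1:
            raise ValueError(
                "num_lookahead_slots "
                f"({self.num_lookahead_slots}) must be 1 for delayed sampling.")


ToySchedulerConfig(num_lookahead_slots=1, enable_delayed_sampling=True)  # ok
try:
    ToySchedulerConfig(enable_delayed_sampling=True)  # 0 lookahead slots
except ValueError as err:
    print(err)
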
9 changes: 9 additions & 0 deletions vllm/engine/arg_utils.py
@@ -170,6 +170,7 @@ class EngineArgs:

scheduler_delay_factor: float = 0.0
enable_chunked_prefill: Optional[bool] = None
enable_delayed_sampling: bool = False

guided_decoding_backend: str = 'xgrammar'
# Speculative decoding configuration.
@@ -714,6 +715,13 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
const="True",
help='If set, the prefill requests can be chunked based on the '
'max_num_batched_tokens.')
parser.add_argument(
'--enable-delayed-sampling',
action='store_true',
help='If set, sampling is delayed by one step: the first '
'model execution for a request (prefill) returns an invalid '
'token id that is discarded, and valid token ids are sampled '
'starting from the second model execution.')

parser.add_argument(
'--speculative-model',
@@ -1178,6 +1186,7 @@ def create_engine_config(self,
enable_chunked_prefill=self.enable_chunked_prefill,
is_multimodal_model=model_config.is_multimodal_model,
preemption_mode=self.preemption_mode,
enable_delayed_sampling=self.enable_delayed_sampling,
num_scheduler_steps=self.num_scheduler_steps,
multi_step_stream_outputs=self.multi_step_stream_outputs,
send_delta_data=(envs.VLLM_USE_RAY_SPMD_WORKER
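For reference, once this patch is applied the flag should be usable both from the CLI (--enable-delayed-sampling) and programmatically, since LLM forwards extra keyword arguments to EngineArgs. A hedged usage sketch; the model name is a placeholder and the setup assumes the habana_main fork with this PR applied:

# Programmatic equivalent of:
#   --enable-delayed-sampling --num-lookahead-slots 1
# Assumes this patch is applied; the model name is only a placeholder.
from vllm import LLM, SamplingParams

llm = LLM(
    model="facebook/opt-125m",      # placeholder model
    enable_delayed_sampling=True,   # new EngineArgs field from this PR
    num_lookahead_slots=1,          # must be 1 (enforced in SchedulerConfig)
)
outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(temperature=0.0, max_tokens=8))
print(outputs[0].outputs[0].text)
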
4 changes: 3 additions & 1 deletion vllm/engine/output_processor/interfaces.py
@@ -36,7 +36,9 @@ def create_output_processor(
This returns a single-step output processor if num_lookahead_slots is
zero (or one, when delayed sampling is enabled); otherwise it returns a
multi-step output processor.
"""
if scheduler_config.num_lookahead_slots == 0:
if (scheduler_config.num_lookahead_slots == 0
or (scheduler_config.num_lookahead_slots == 1
and scheduler_config.enable_delayed_sampling)):
# Importing here to avoid cycle.
from vllm.engine.output_processor.single_step import (
SingleStepOutputProcessor)
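In other words, the single-step output processor is now also selected when the only lookahead slot exists for delayed sampling. A standalone restatement of the modified condition (the helper name is mine, not part of the PR):

def uses_single_step_processor(num_lookahead_slots: int,
                               enable_delayed_sampling: bool) -> bool:
    # Restates the condition in create_output_processor(): single-step
    # when there are no lookahead slots, or when the single lookahead
    # slot exists only to support delayed sampling.
    return (num_lookahead_slots == 0
            or (num_lookahead_slots == 1 and enable_delayed_sampling))


assert uses_single_step_processor(0, False)      # default decoding path
assert uses_single_step_processor(1, True)       # delayed sampling path
assert not uses_single_step_processor(1, False)  # multi-step / spec decode
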
4 changes: 3 additions & 1 deletion vllm/engine/output_processor/single_step.py
@@ -116,7 +116,9 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup,

sample = outputs.samples[0]
seq = seq_group.first_seq
if not is_async:
# -1 means the output token is not valid (e.g. the first token when
# delayed sampling is enabled).
if not is_async and sample.output_token != -1:
seq.append_token_id(sample.output_token, sample.logprobs)
if sampling_params.detokenize and self.detokenizer:
new_char_count = self.detokenizer.decode_sequence_inplace(
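To make the sentinel explicit: under delayed sampling the first (prefill) execution emits a dummy token id of -1 that must never be appended to the sequence. A toy sketch of that filtering (ToySample is a made-up stand-in for SequenceOutput):

from typing import List, NamedTuple


class ToySample(NamedTuple):
    # Stand-in for SequenceOutput: just the sampled token id.
    output_token: int


def tokens_to_append(samples: List[ToySample]) -> List[int]:
    # The prefill step under delayed sampling emits a dummy token id of
    # -1, which is discarded instead of being appended to the sequence.
    return [s.output_token for s in samples if s.output_token != -1]


steps = [ToySample(-1), ToySample(128), ToySample(42)]
print(tokens_to_append(steps))  # [128, 42] -- the placeholder is dropped
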
32 changes: 23 additions & 9 deletions vllm/model_executor/layers/sampler.py
@@ -186,6 +186,7 @@
# containing the sampled token ids and probabilities. This is used by
# speculative decoding.
self.include_gpu_probs_tensor = False
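# If True, sampling returns each sampled token's position in the batch
# instead of its token id, so the ids can stay on the device
# (used for delayed sampling).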
self.sample_token_positions_only = False
self.should_modify_greedy_probs_inplace = False

def _init_sampling_tensors(
@@ -302,6 +303,7 @@
sampling_tensors,
include_gpu_probs_tensor=self.include_gpu_probs_tensor,
modify_greedy_probs=self._should_modify_greedy_probs_inplace,
token_positions_only=self.sample_token_positions_only,
)

if self.include_gpu_probs_tensor:
@@ -588,6 +590,7 @@
def _greedy_sample(
selected_seq_groups: List[SequenceGroupToSample],
samples: torch.Tensor,
token_positions_only: bool = False,
) -> SampleResultType:
"""Run greedy sampling on a given samples.

@@ -601,7 +604,9 @@
same as the length of selected_seq_groups. If the corresponding
seq_group has do_sample=False, tuple contains ([], [])
"""
samples_lst = samples.tolist()
print("DEBUG sync in greedy_sample")
if not token_positions_only:
samples_lst = samples.tolist()
sample_idx = 0
results: SampleResultType = []
for seq_group in selected_seq_groups:
@@ -614,7 +619,9 @@
assert num_parent_seqs == 1, (
"Greedy sampling should have only one seq.")
parent_ids = list(range(num_parent_seqs))
next_token_ids = [samples_lst[sample_idx]]
next_token_ids = [
sample_idx if token_positions_only else samples_lst[sample_idx]
]
results.append((next_token_ids, parent_ids))
sample_idx += num_parent_seqs
return results
@@ -637,7 +644,8 @@
seq_group has do_sample=False, tuple contains ([], [])
"""
# Find the maximum n value of the prompt phase requests.
random_samples = random_samples.cpu()
print("DEBUG sync in random_sample")
# TODO: use torch.zeros here instead, to avoid copying to the CPU.
random_samples = random_samples.cpu()

[CI annotation, GitHub Actions / ruff (3.12): vllm/model_executor/layers/sampler.py:648:81: E501 Line too long (99 > 80)]
sample_idx = 0
results: SampleResultType = []
for seq_group in selected_seq_groups:
@@ -653,7 +661,7 @@
# Prompt phase.
parent_ids = [0] * sampling_params.n
next_token_ids = random_samples[
sample_idx, :sampling_params.n].tolist()
sample_idx, :sampling_params.n].tolist()  # also a host-side op; needed?

[CI annotation, GitHub Actions / ruff (3.12): vllm/model_executor/layers/sampler.py:664:81: E501 Line too long (98 > 80)]
else:
# Generation phase.
parent_ids = list(range(num_parent_seqs))
@@ -798,7 +806,8 @@


def get_pythonized_sample_results(
sample_result_args: SampleResultArgsType) -> SampleResultType:
sample_result_args: SampleResultArgsType,
token_positions_only: bool = False) -> SampleResultType:
'''This function consumes GPU-side sampler results and computes
Pythonized CPU-side sampler results (GPU -> CPU sync.)

@@ -836,7 +845,8 @@
continue
(seq_group_id, seq_groups) = sample_metadata[sampling_type]
if sampling_type == SamplingType.GREEDY:
sample_results = _greedy_sample(seq_groups, greedy_samples)
sample_results = _greedy_sample(seq_groups, greedy_samples,
token_positions_only)
elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
sample_results = _random_sample(seq_groups,
multinomial_samples[sampling_type])
@@ -858,6 +868,7 @@
sampling_tensors: SamplingTensors,
include_gpu_probs_tensor: bool,
modify_greedy_probs: bool,
token_positions_only: bool = False,
) -> SampleReturnType:
'''Torch-oriented _sample() implementation.

@@ -967,13 +978,14 @@
greedy_samples=greedy_samples,
beam_search_logprobs=beam_search_logprobs,
sample_results_dict=sample_results_dict)

#import pdb; pdb.set_trace()
if not sampling_metadata.skip_sampler_cpu_output:
# GPU<->CPU sync happens here.
# GPU<->CPU sync happens here,
# unless we're storing only token positions (token_positions_only=True).
# This also converts the sampler output to a Python object.
# Return Pythonized sampler result & sampled token ids
return get_pythonized_sample_results(
maybe_deferred_args), sampled_token_ids_tensor
maybe_deferred_args, token_positions_only), sampled_token_ids_tensor
else:
# Defer sampler result Pythonization; return deferred
# Pythonization args & sampled token ids
@@ -990,6 +1002,7 @@
sampling_tensors: SamplingTensors,
include_gpu_probs_tensor: bool,
modify_greedy_probs: bool,
token_positions_only: bool,
) -> SampleReturnType:
"""
Args:
@@ -1010,6 +1023,7 @@
sampling_tensors,
include_gpu_probs_tensor=include_gpu_probs_tensor,
modify_greedy_probs=modify_greedy_probs,
token_positions_only=token_positions_only,
)


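The net effect of token_positions_only is easiest to see in isolation: instead of calling .tolist() on the sampled ids (a device-to-host sync), greedy sampling can return batch positions and leave the actual token ids on the device for later. A reduced sketch; the function name is mine and only the flag's behaviour mirrors the patch:

from typing import List

import torch


def toy_greedy_sample(samples: torch.Tensor,
                      token_positions_only: bool = False) -> List[int]:
    # Reduced version of the modified _greedy_sample(): one sequence per
    # group, do_sample always True.
    if not token_positions_only:
        # Original behaviour: device -> host copy of the sampled token ids.
        samples_lst = samples.tolist()
    results: List[int] = []
    for sample_idx in range(samples.shape[0]):
        # With token_positions_only the caller gets the *position* of the
        # sample in the batch; the real token id stays on the device, so
        # the GPU<->CPU sync is deferred to a later step.
        results.append(sample_idx if token_positions_only
                       else samples_lst[sample_idx])
    return results


greedy = torch.tensor([17, 5, 99])
print(toy_greedy_sample(greedy))                             # [17, 5, 99]
print(toy_greedy_sample(greedy, token_positions_only=True))  # [0, 1, 2]
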
15 changes: 9 additions & 6 deletions vllm/sequence.py
@@ -175,6 +175,10 @@ class SequenceData(msgspec.Struct,
# It is used to compute mrope_position_ids.
_mrope_position_delta: Optional[int] = None

# These are used in delayed prompt sampling.
prev_logits = None
prev_logits_idx = None

@staticmethod
def from_prompt_token_counts(
*token_counts: Tuple[int, int]) -> "SequenceData":
@@ -279,11 +283,11 @@ def mrope_position_delta(self) -> Optional[int]:
def mrope_position_delta(self, new_mrope_position_delta):
self._mrope_position_delta = new_mrope_position_delta

def append_token_id(self, token_id: int, logprob: float) -> None:
def append_token_id(self, token_id: int, logprob: Optional[float]) -> None:
self._output_token_ids.append(token_id)
self._new_appended_tokens.append(token_id)
self._cached_all_token_ids.append(token_id)
self._cumulative_logprob += logprob
self._cumulative_logprob += logprob if logprob is not None else 0.0

def get_len(self) -> int:
return len(self._output_token_ids) + len(self._prompt_token_ids)
Expand Down Expand Up @@ -315,8 +319,6 @@ def get_num_computed_tokens(self) -> int:
def update_num_computed_tokens(self, num_new_computed_tokens: int):
"""Update number of tokens computed so far."""
self._num_computed_tokens += num_new_computed_tokens
assert self._num_computed_tokens <= self.get_len(), (
self._num_computed_tokens, self.get_len())
# If all tokens are computed, it means it is in decoding phase.
if self.get_num_uncomputed_tokens() == 0:
self._stage = SequenceStage.DECODE
@@ -536,9 +538,10 @@ def reset_state_for_recompute(self):

def append_token_id(self, token_id: int, logprobs: Dict[int,
Logprob]) -> None:
assert token_id in logprobs
self.output_logprobs.append(logprobs)
self.data.append_token_id(token_id, logprobs[token_id].logprob)
self.data.append_token_id(
token_id,
logprobs[token_id].logprob if token_id in logprobs else None)

def get_len(self) -> int:
return self.data.get_len()
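The sequence-side changes relax the bookkeeping so a token can be appended before its logprob is known (the delayed/placeholder case). A reduced sketch of the new behaviour (ToySequenceData is illustrative; only the Optional-logprob handling mirrors the patch):

from typing import List, Optional


class ToySequenceData:
    # Keeps just the fields needed to show the Optional-logprob path.

    def __init__(self) -> None:
        self.output_token_ids: List[int] = []
        self.cumulative_logprob: float = 0.0

    def append_token_id(self, token_id: int,
                        logprob: Optional[float]) -> None:
        self.output_token_ids.append(token_id)
        # Delayed sampling can append a token whose logprob is not known
        # yet (e.g. the -1 placeholder from the prefill step); count it
        # as 0.0 toward the cumulative logprob.
        self.cumulative_logprob += logprob if logprob is not None else 0.0


data = ToySequenceData()
data.append_token_id(42, -0.7)   # normal decode step
data.append_token_id(-1, None)   # placeholder token without a logprob
print(data.output_token_ids, data.cumulative_logprob)  # [42, -1] -0.7
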