[New Feature][Habana main] spec decode PR2 - Medusa, MLP, Eagle (#461)
Spec Decode PR2 - enable Medusa, MLP, and Eagle speculative decoding

This PR is an add-on to #375
=> Do not merge until #375 is merged
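
For reference, a minimal MLPSpeculator-style invocation mirroring the examples added below (a sketch only: the speculator checkpoint name is a placeholder and is not part of this PR):

from vllm import LLM, SamplingParams

# Sketch only: substitute an MLPSpeculator checkpoint trained for the target model.
llm = LLM(
    model="JackFram/llama-68m",
    speculative_model="path/to/mlp-speculator",  # placeholder, not a real checkpoint
    use_v2_block_manager=True,
    # num_speculative_tokens can be passed as in the Eagle/Medusa examples below.
)
outputs = llm.generate(["The future of AI is"],
                       SamplingParams(temperature=0.8, top_p=0.95, max_tokens=20))
print(outputs[0].outputs[0].text)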
xuechendi authored Nov 12, 2024
1 parent 1565944 commit 890b1f0
Showing 11 changed files with 313 additions and 62 deletions.
68 changes: 68 additions & 0 deletions examples/offline_inference_eaglespeculator.py
@@ -0,0 +1,68 @@
import gc
import time
from typing import List

from vllm import LLM, SamplingParams


def time_generation(llm: LLM, prompts: List[str],
                    sampling_params: SamplingParams):
    # Generate texts from the prompts. The output is a list of RequestOutput
    # objects that contain the prompt, generated text, and other information.
    # Warmup first
    llm.generate(prompts, sampling_params)
    llm.generate(prompts, sampling_params)
    start = time.time()
    outputs = llm.generate(prompts, sampling_params)
    end = time.time()
    latency_per_token = (end - start) / sum(
        [len(o.outputs[0].token_ids) for o in outputs])
    # Collect the generated texts.
    ret = []
    for output in outputs:
        generated_text = output.outputs[0].text
        ret.append(generated_text)
    return ret, latency_per_token


if __name__ == "__main__":

    prompts = [
        "The future of AI is",
    ]
    sampling_params = SamplingParams(temperature=0.8,
                                     top_p=0.95,
                                     max_tokens=20)

    # Create an LLM without spec decoding
    print("==============Without speculation==================")
    llm = LLM(model="JackFram/llama-68m")

    ret_non_spec, latency_per_token_non_spec = time_generation(
        llm, prompts, sampling_params)

    del llm
    gc.collect()

    # Create an LLM with spec decoding
    print("==============With speculation=====================")
    llm = LLM(
        model="JackFram/llama-68m",
        speculative_model="abhigoyal/vllm-eagle-llama-68m-random",
        num_speculative_tokens=5,
        # This is currently required for EAGLE speculative decoding
        use_v2_block_manager=True,
    )

    ret_spec, latency_per_token_spec = time_generation(llm, prompts,
                                                       sampling_params)

    del llm
    gc.collect()
    print("================= Summary =====================")
    print("input is ", prompts, "\n")
    print("Non Spec Decode - latency_per_token is ",
          latency_per_token_non_spec)
    print("Generated Text is :", ret_non_spec, "\n")
    print("Spec Decode - latency_per_token is ", latency_per_token_spec)
    print("Generated Text is :", ret_spec)
67 changes: 67 additions & 0 deletions examples/offline_inference_medusaspeculator.py
@@ -0,0 +1,67 @@
import gc
import time
from typing import List

from vllm import LLM, SamplingParams


def time_generation(llm: LLM, prompts: List[str],
                    sampling_params: SamplingParams):
    # Generate texts from the prompts. The output is a list of RequestOutput
    # objects that contain the prompt, generated text, and other information.
    # Warmup first
    llm.generate(prompts, sampling_params)
    llm.generate(prompts, sampling_params)
    start = time.time()
    outputs = llm.generate(prompts, sampling_params)
    end = time.time()
    latency_per_token = (end - start) / sum(
        [len(o.outputs[0].token_ids) for o in outputs])
    # Collect the generated texts.
    ret = []
    for output in outputs:
        generated_text = output.outputs[0].text
        ret.append(generated_text)
    return ret, latency_per_token


if __name__ == "__main__":

    prompts = [
        "The future of AI is",
    ]
    sampling_params = SamplingParams(temperature=0.8,
                                     top_p=0.95,
                                     max_tokens=20)

    # Create an LLM without spec decoding
    print("==============Without speculation==================")
    llm = LLM(model="JackFram/llama-68m")

    ret_non_spec, latency_per_token_non_spec = time_generation(
        llm, prompts, sampling_params)

    del llm
    gc.collect()

    # Create an LLM with spec decoding
    print("==============With speculation=====================")
    llm = LLM(
        model="JackFram/llama-68m",
        speculative_model="abhigoyal/vllm-medusa-llama-68m-random",
        num_speculative_tokens=5,
        use_v2_block_manager=True,
    )

    ret_spec, latency_per_token_spec = time_generation(llm, prompts,
                                                       sampling_params)

    del llm
    gc.collect()
    print("================= Summary =====================")
    print("input is ", prompts, "\n")
    print("Non Spec Decode - latency_per_token is ",
          latency_per_token_non_spec)
    print("Generated Text is :", ret_non_spec, "\n")
    print("Spec Decode - latency_per_token is ", latency_per_token_spec)
    print("Generated Text is :", ret_spec)
1 change: 0 additions & 1 deletion vllm/executor/hpu_executor.py
@@ -42,7 +42,6 @@ def _get_worker_kwargs(
rank=rank,
distributed_init_method=distributed_init_method,
is_driver_worker=rank == 0,
speculative_config=self.speculative_config,
)

def _create_worker(self,
62 changes: 62 additions & 0 deletions vllm/spec_decode/hpu_draft_model_runner.py
@@ -0,0 +1,62 @@
from typing import List, Optional

import torch

from vllm.logger import init_logger
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.sequence import IntermediateTensors
from vllm.worker.hpu_model_runner import HPUModelRunner as ModelRunnerBaseCls
from vllm.worker.hpu_model_runner import ModelInputForHPUWithSamplingMetadata

logger = init_logger(__name__)

# A flag to enable debug prints for the updated input tensors
# before each step.
debug_advance_input = False
# A flag to allow GPU advance step for draft model runner.
# Set to False for debugging.
allow_gpu_advance_step = True


class HPUTP1DraftModelRunner(ModelRunnerBaseCls):
    """Specialized model runner for the speculative decoding draft model.

    Since the draft model always executes k forward passes consecutively to
    generate k speculative tokens in a single speculative decoding step,
    we could get rid of most CPU-GPU synchronization and data transfer
    overheads by keeping model input and output tensors on GPU all the time.

    TODOs:
    1. Support TP > 1 (this requires some design work because we do not expect
       any broadcasting inside execute_model).
    """

    def __init__(self, *args, **kwargs):
        if kwargs.get("return_hidden_states"):
            raise ValueError(
                "return_hidden_states is not supported for TP1DraftModelRunner."
            )

        super().__init__(*args, **kwargs)

        self.indices_of_seq_with_bonus_tokens = None

    @torch.inference_mode()
    def execute_model(
        self,
        model_input: ModelInputForHPUWithSamplingMetadata,
        kv_caches: List[torch.Tensor],
        previous_hidden_states: Optional[torch.Tensor] = None,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        num_steps: int = 1,
    ) -> Optional[List[SamplerOutput]]:
        if previous_hidden_states is not None:
            _, block_size = model_input.input_tokens.shape
            previous_hidden_states = previous_hidden_states.expand(
                block_size, -1).unsqueeze(0)
        return super().execute_model(
            model_input=model_input,
            kv_caches=kv_caches,
            previous_hidden_states=previous_hidden_states,
            intermediate_tensors=intermediate_tensors,
            num_steps=num_steps,
        )
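
For reference, a standalone sketch of the hidden-state reshaping performed in execute_model above (the concrete sizes are assumptions used only for illustration):

import torch

# Assumed sizes for illustration only.
hidden_size = 64
block_size = 128  # second dim of model_input.input_tokens (padded block layout)

# The target model hands back one hidden-state row for the sequence...
previous_hidden_states = torch.randn(1, hidden_size)

# ...which the draft runner broadcasts across the padded block and gives a batch dim.
expanded = previous_hidden_states.expand(block_size, -1).unsqueeze(0)
print(expanded.shape)  # torch.Size([1, 128, 64])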
4 changes: 2 additions & 2 deletions vllm/spec_decode/medusa_worker.py
@@ -9,10 +9,10 @@
from vllm.spec_decode.interfaces import SpeculativeProposals
from vllm.spec_decode.proposer_worker_base import NonLLMProposerWorkerBase
from vllm.spec_decode.top1_proposer import Top1Proposer
from vllm.worker.worker import Worker
from vllm.worker.selector import WorkerCls


class MedusaWorker(NonLLMProposerWorkerBase, Worker):
class MedusaWorker(NonLLMProposerWorkerBase, WorkerCls):
"""Worker for Medusa.
"""

21 changes: 3 additions & 18 deletions vllm/spec_decode/multi_step_worker.py
@@ -5,32 +5,17 @@
import torch

from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.platforms import current_platform
from vllm.sequence import (ExecuteModelRequest, HiddenStates, SequenceData,
SequenceGroupMetadata)
from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner
from vllm.spec_decode.interfaces import (SpeculativeProposals,
SpeculativeProposer)
from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase
from vllm.spec_decode.top1_proposer import Top1Proposer
from vllm.worker.selector import WorkerCls

if current_platform.is_neuron():
from vllm.worker.neuron_worker import NeuronWorker as WorkerBaseCls
elif current_platform.is_hpu():
from vllm.worker.hpu_worker import HPUWorker as WorkerBaseCls
elif current_platform.is_openvino:
from vllm.worker.openvino_worker import OpenVINOWorker as WorkerBaseCls
elif current_platform.is_cpu():
from vllm.worker.cpu_worker import CPUWorker as WorkerBaseCls
elif current_platform.is_tpu():
from vllm.worker.tpu_worker import TPUWorker as WorkerBaseCls
elif current_platform.is_xpu():
from vllm.worker.xpu_worker import XPUWorker as WorkerBaseCls
else:
from vllm.worker.worker import Worker as WorkerBaseCls


class MultiStepWorker(WorkerBaseCls, ProposerWorkerBase):

class MultiStepWorker(WorkerCls, ProposerWorkerBase):
"""The MultiStepWorker is equivalent to a Worker except that it allows
multiple forward passes in a single call, assuming the scheduler has
allocated enough space to store the additional KV. This reduces overhead
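
MedusaWorker and MultiStepWorker now inherit from WorkerCls instead of repeating the per-platform if/elif chain removed above. The selector module itself is not shown in this diff; a minimal sketch of what vllm/worker/selector.py might contain, reconstructed from the removed branches, is:

# Sketch of vllm/worker/selector.py (assumed; the actual module is not part of this diff).
from vllm.platforms import current_platform

if current_platform.is_neuron():
    from vllm.worker.neuron_worker import NeuronWorker as WorkerCls
elif current_platform.is_hpu():
    from vllm.worker.hpu_worker import HPUWorker as WorkerCls
elif current_platform.is_openvino():
    from vllm.worker.openvino_worker import OpenVINOWorker as WorkerCls
elif current_platform.is_cpu():
    from vllm.worker.cpu_worker import CPUWorker as WorkerCls
elif current_platform.is_tpu():
    from vllm.worker.tpu_worker import TPUWorker as WorkerCls
elif current_platform.is_xpu():
    from vllm.worker.xpu_worker import XPUWorker as WorkerCls
else:
    from vllm.worker.worker import Worker as WorkerCls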
19 changes: 16 additions & 3 deletions vllm/spec_decode/spec_decode_worker.py
@@ -14,12 +14,12 @@
SpecDecodeBaseSampler, SpecDecodeStochasticBaseSampler)
from vllm.model_executor.layers.typical_acceptance_sampler import (
TypicalAcceptanceSampler)
from vllm.platforms import current_platform
from vllm.sequence import (VLLM_INVALID_TOKEN_ID,
CompletionSequenceGroupOutput, ExecuteModelRequest,
HiddenStates, SequenceGroupMetadata,
get_all_seq_ids_and_request_ids)
from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer
from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner
from vllm.spec_decode.interfaces import (SpeculativeProposals,
SpeculativeScorer, SpeculativeScores)
from vllm.spec_decode.medusa_worker import MedusaWorker
@@ -40,6 +40,11 @@
from vllm.worker.worker import Worker
from vllm.worker.worker_base import LoraNotSupportedWorkerBase, WorkerBase

if current_platform.is_hpu():
from vllm.spec_decode.hpu_draft_model_runner import HPUTP1DraftModelRunner
else:
from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner

logger = init_logger(__name__)


@@ -159,8 +164,16 @@ def create_worker(
proposer_worker = MedusaWorker(**draft_worker_kwargs)
else:
if draft_tp == 1:
draft_worker_kwargs[
"model_runner_cls"] = TP1DraftModelRunner
if current_platform.is_cuda_alike():
draft_worker_kwargs[
"model_runner_cls"] = TP1DraftModelRunner
elif current_platform.is_hpu():
draft_worker_kwargs[
"model_runner_cls"] = HPUTP1DraftModelRunner
else:
raise NotImplementedError(
"DraftModelRunner not implemented for this platform"
)
else:
if draft_model_config.hf_config.model_type == "eagle":
raise NotImplementedError(
45 changes: 37 additions & 8 deletions vllm/spec_decode/target_model_runner.py
@@ -1,12 +1,42 @@
from typing import List, Optional

from vllm.config import VllmConfig
from vllm.platforms import current_platform
from vllm.sequence import SequenceGroupMetadata
from vllm.worker.model_runner import (ModelInputForGPUWithSamplingMetadata,
ModelRunner)

if current_platform.is_cuda_alike():
from vllm.worker.model_runner import (
ModelInputForGPUWithSamplingMetadata as ModelInputCls) # yapf: disable
from vllm.worker.model_runner import ModelRunner as ModelRunnerCls
elif current_platform.is_neuron():
from vllm.worker.neuron_model_runner import (
ModelInputForNeuron as ModelInputCls) # yapf: disable
from vllm.worker.neuron_model_runner import (
NeuronModelRunner as ModelRunnerCls) # yapf: disable
elif current_platform.is_hpu():
from vllm.worker.hpu_model_runner import HPUModelRunner as ModelRunnerCls
from vllm.worker.hpu_model_runner import (
ModelInputForHPUWithSamplingMetadata as ModelInputCls) # yapf: disable
elif current_platform.is_openvino():
from vllm.worker.openvino_model_runner import ModelInput as ModelInputCls
from vllm.worker.openvino_model_runner import (
OpenVINOModelRunner as ModelRunnerCls) # yapf: disable
elif current_platform.is_cpu():
from vllm.worker.cpu_model_runner import CPUModelRunner as ModelRunnerCls
from vllm.worker.cpu_model_runner import (
ModelInputForCPUWithSamplingMetadata as ModelInputCls) # yapf: disable
elif current_platform.is_tpu():
from vllm.worker.tpu_model_runner import ModelInputForTPU as ModelInputCls
from vllm.worker.tpu_model_runner import TPUModelRunner as ModelRunnerCls
elif current_platform.is_xpu():
from vllm.worker.xpu_model_runner import (
ModelInputForXPUWithSamplingMetadata as ModelInputCls) # yapf: disable
from vllm.worker.xpu_model_runner import XPUModelRunner as ModelRunnerCls
else:
raise ValueError(f"Unsupported platform: {current_platform}")

class TargetModelRunner(ModelRunner):

class TargetModelRunner(ModelRunnerCls):
"""Specialized model runner for speculative decoding target model.
In speculative decoding, the log probabilities selected finally may not
be the same ones as selected by the target model sampling. This means
@@ -39,11 +69,10 @@ def prepare_model_input(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
virtual_engine: int = 0,
finished_requests_ids: Optional[List[str]] = None
) -> ModelInputForGPUWithSamplingMetadata:
model_input: ModelInputForGPUWithSamplingMetadata = super(
).prepare_model_input(seq_group_metadata_list, virtual_engine,
finished_requests_ids)
finished_requests_ids: Optional[List[str]] = None,
) -> ModelInputCls:
model_input: ModelInputCls = super().prepare_model_input(
seq_group_metadata_list, virtual_engine, finished_requests_ids)
# If token log probabilities is disabled then skip generating sampler
# CPU output. We directly serialize the GPU sampled_token_id tensors
# as needed. If log probabilities is enabled then synchronize all the
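
The fast path described in the comment above applies when requests do not ask for token log probabilities; a minimal illustration (assuming the default SamplingParams fields) is:

from vllm import SamplingParams

# With logprobs and prompt_logprobs left at their defaults (None), the target
# runner can skip building the sampler's CPU output and serialize the sampled
# token ids from the device tensors directly.
params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=20)
assert params.logprobs is None and params.prompt_logprobs is None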