[WIP] Add HPU support to vLLM v1 - cont. #609

Open · wants to merge 44 commits into `habana_main`

Commits (44)
All 44 commits were authored by kzawora-intel.

- `2191184` vLLM v1 HPU prototype (Nov 12, 2024)
- `fd77180` copy gpu model runner code, add hpugraphs support and profile run (Nov 12, 2024)
- `4dadef5` i am very much struggling (Nov 13, 2024)
- `9db1409` it's hopeless (Nov 13, 2024)
- `3b3098c` [wip] bypass prefill chunking in v1 scheduler (Nov 14, 2024)
- `c24adb5` colonoscopy (Nov 14, 2024)
- `2da069e` prefill runs, decode has deadlock, idk why (Nov 14, 2024)
- `932ce93` i'm done for today (Nov 14, 2024)
- `fc6a1c2` do better job at prefill chunking detection (Nov 14, 2024)
- `ff0ed54` mixed batch scheduling is still a problem (Nov 15, 2024)
- `50aa6b3` Merge remote-tracking branch 'origin/habana_main' into HEAD (Nov 18, 2024)
- `debec16` general hpu code rewrite (Nov 18, 2024)
- `0c1d0b6` add debug stuff, it seems like prefill is functional (Nov 18, 2024)
- `35d3e38` slight code cleanup (Nov 18, 2024)
- `491f991` remove garbage changes (Nov 18, 2024)
- `e29b84a` gsm8k now produces 69% acc on llama3.1 (Nov 19, 2024)
- `27b4f32` add config not warmed up warnings (Nov 19, 2024)
- `087b5d2` add bucketinggit add -u .! (Nov 19, 2024)
- `6fdb6a9` llama3.1 now gives 81% in gsm8k without contiguous pa (Nov 19, 2024)
- `8714f9d` disable contiguous pa by default (Nov 19, 2024)
- `40ff0ac` async data copy (Nov 19, 2024)
- `28f2ac5` add split sampler optimization (Nov 19, 2024)
- `df7a1d4` add prompt batching (Nov 20, 2024)
- `623ed10` padded logits_indices and sampling + documentation (Nov 20, 2024)
- `0371c31` update docs (Nov 20, 2024)
- `d7b2a06` fix first-party random and greedy sampler for hpu (Nov 20, 2024)
- `c934e60` format.sh (Nov 20, 2024)
- `e0f4c26` add warmup w/ sampler (it doesn't work great tho) (Nov 21, 2024)
- `58c8f5d` add hpugraph check (Nov 21, 2024)
- `0c8b075` fix async engine, fix sampler corner cases (Nov 22, 2024)
- `fecedb5` Add padding-aware scheduling (Nov 25, 2024)
- `2ab1ac8` Merge remote-tracking branch 'origin/habana_main' into HEAD (Nov 25, 2024)
- `0d41073` bucketing refactor, enable contiguous pa, defrag blocks (Nov 26, 2024)
- `5645523` FreeKVCacheBlockHeapQueue bugfixes (Nov 26, 2024)
- `fd62723` [wip] add prefix caching support (it was actually really hard) (Nov 26, 2024)
- `e80f2be` fix hpugraphs and long seq corner case (Dec 4, 2024)
- `63af84c` Merge remote-tracking branch 'origin/private/kzawora/dec_10_rebase' i… (Dec 10, 2024)
- `061f037` fix uniproc executor (Dec 10, 2024)
- `8a28658` add multiproc hpu executor (Dec 10, 2024)
- `a6db04e` format.sh (Dec 10, 2024)
- `7955a5d` Merge remote-tracking branch 'origin/private/kzawora/dec_10_rebase' i… (Dec 10, 2024)
- `64e5dd5` Merge remote-tracking branch 'origin/habana_main' into private/kzawor… (Dec 10, 2024)
- `6d598d7` Merge remote-tracking branch 'origin/habana_main' into HEAD (Dec 11, 2024)
- `897b042` whoopsie i forgot to add executors (Dec 12, 2024)
Files changed
vllm/attention/selector.py (4 additions, 0 deletions)

@@ -159,6 +159,10 @@ def _cached_get_attn_backend(
         logger.info("Using HPUAttention backend.")
         from vllm.attention.backends.hpu_attn import HPUAttentionBackend
         return HPUAttentionBackend
+    elif backend == _Backend.HPU_ATTN_V1:
+        logger.info("Using HPUAttentionV1 backend.")
+        from vllm.v1.attention.backends.hpu_attn import HPUAttentionBackendV1
+        return HPUAttentionBackendV1
     elif backend == _Backend.PALLAS:
         logger.info("Using Pallas backend.")
         from vllm.attention.backends.pallas import PallasAttentionBackend
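For context on how this hunk fits in: `_cached_get_attn_backend` is an if/elif dispatcher that lazily imports the backend class for whichever `_Backend` value was selected, so the new `HPU_ATTN_V1` branch only pulls in `vllm.v1.attention.backends.hpu_attn` when the v1 HPU backend is actually requested. Below is a minimal registry-style sketch of the same lazy-import idea; the enum members, the registry dict, and `get_backend_cls` are illustrative stand-ins, not the actual vLLM API (which uses the if/elif chain shown in the diff).

```python
import importlib
from enum import Enum, auto


class _Backend(Enum):
    # Simplified stand-in for vLLM's backend enum.
    HPU_ATTN = auto()
    HPU_ATTN_V1 = auto()
    PALLAS = auto()


# Map each backend to "module_path:class_name". Imports stay lazy so that
# platforms without the corresponding dependencies can still run the selector.
_BACKEND_REGISTRY = {
    _Backend.HPU_ATTN: "vllm.attention.backends.hpu_attn:HPUAttentionBackend",
    _Backend.HPU_ATTN_V1: "vllm.v1.attention.backends.hpu_attn:HPUAttentionBackendV1",
    _Backend.PALLAS: "vllm.attention.backends.pallas:PallasAttentionBackend",
}


def get_backend_cls(backend: _Backend):
    """Resolve a backend enum value to its class, importing only on demand."""
    module_path, class_name = _BACKEND_REGISTRY[backend].split(":")
    module = importlib.import_module(module_path)
    return getattr(module, class_name)
```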
vllm/compilation/backends.py (2 additions, 1 deletion)

@@ -11,6 +11,7 @@
 import vllm.envs as envs
 from vllm.config import CompilationConfig
 from vllm.logger import init_logger
+from vllm.platforms import current_platform
 from vllm.utils import weak_ref_tensors
 
 from .counter import compilation_counter
@@ -230,7 +231,7 @@ def __init__(
         compilation_configs: CompilationConfig,
     ):
         global global_graph_pool
-        if global_graph_pool is None:
+        if global_graph_pool is None and current_platform.is_cuda_alike():
             global_graph_pool = torch.cuda.graph_pool_handle()
 
         # TODO: in the future, if we want to use multiple
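The guard in the second hunk keeps the shared graph pool CUDA-only: `torch.cuda.graph_pool_handle()` is part of the CUDA Graphs API and is not available on HPU, so the pool handle is only created on CUDA-alike platforms. A minimal sketch of that guard pattern is shown below; `_PlatformStub` and `maybe_init_graph_pool` are hypothetical stand-ins for `vllm.platforms.current_platform` and the surrounding `__init__` logic.

```python
import torch


class _PlatformStub:
    """Hypothetical stand-in for vllm.platforms.current_platform."""

    def is_cuda_alike(self) -> bool:
        # Treat any available CUDA/ROCm device as "CUDA-alike".
        return torch.cuda.is_available()


current_platform = _PlatformStub()
global_graph_pool = None


def maybe_init_graph_pool():
    """Create the shared CUDA graph memory pool only where the API applies."""
    global global_graph_pool
    if global_graph_pool is None and current_platform.is_cuda_alike():
        # torch.cuda.graph_pool_handle() assumes a CUDA(-like) runtime, so
        # HPU and CPU-only builds skip it entirely.
        global_graph_pool = torch.cuda.graph_pool_handle()
```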
vllm/distributed/parallel_state.py (1 addition, 0 deletions)

@@ -959,6 +959,7 @@ def init_distributed_environment(
         "world_size=%d rank=%d local_rank=%d "
         "distributed_init_method=%s backend=%s", world_size, rank, local_rank,
         distributed_init_method, backend)
+    print(distributed_init_method)
     if not torch.distributed.is_initialized():
         assert distributed_init_method is not None, (
             "distributed_init_method must be provided when initializing "
vllm/platforms/hpu.py (12 additions, 1 deletion)

@@ -2,6 +2,8 @@
 
 import torch
 
+from vllm import envs
+
 from .interface import Platform, PlatformEnum, _Backend
 
 if TYPE_CHECKING:
@@ -36,12 +38,21 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
         parallel_config = vllm_config.parallel_config
         if parallel_config.worker_cls == "auto":
             if scheduler_config.is_multi_step:
+                if envs.VLLM_USE_V1:
+                    raise NotImplementedError
                 parallel_config.worker_cls = \
                     "vllm.worker.multi_step_hpu_worker.MultiStepHPUWorker"
             elif vllm_config.speculative_config:
+                if envs.VLLM_USE_V1:
+                    raise NotImplementedError
                 parallel_config.worker_cls = \
                     "vllm.spec_decode.spec_decode_worker.create_spec_worker"
                 parallel_config.sd_worker_cls = \
                     "vllm.worker.hpu_worker.HPUWorker"
             else:
-                parallel_config.worker_cls = "vllm.worker.hpu_worker.HPUWorker"
+                if envs.VLLM_USE_V1:
+                    parallel_config.worker_cls = \
+                        "vllm.v1.worker.hpu_worker.HPUWorker"
+                else:
+                    parallel_config.worker_cls = \
+                        "vllm.worker.hpu_worker.HPUWorker"