diff --git a/Dockerfile.hpu b/Dockerfile.hpu new file mode 100644 index 0000000000000..f481c8c6a57bf --- /dev/null +++ b/Dockerfile.hpu @@ -0,0 +1,16 @@ +FROM vault.habana.ai/gaudi-docker/1.18.0/ubuntu22.04/habanalabs/pytorch-installer-2.4.0:latest + +COPY ./ /workspace/vllm + +WORKDIR /workspace/vllm + +RUN pip install -v -r requirements-hpu.txt + +ENV no_proxy=localhost,127.0.0.1 +ENV PT_HPU_ENABLE_LAZY_COLLECTIVES=true + +RUN VLLM_TARGET_DEVICE=hpu python3 setup.py install + +WORKDIR /workspace/ + +ENTRYPOINT ["python3", "-m", "vllm.entrypoints.openai.api_server"] diff --git a/docs/source/getting_started/gaudi-installation.rst b/docs/source/getting_started/gaudi-installation.rst new file mode 100644 index 0000000000000..68c1a56660fa4 --- /dev/null +++ b/docs/source/getting_started/gaudi-installation.rst @@ -0,0 +1,402 @@ +Installation with Intel® Gaudi® AI Accelerators +=============================================== + +This README provides instructions on running vLLM with Intel Gaudi devices. + +Requirements and Installation +============================= + +Please follow the instructions provided in the `Gaudi Installation +Guide `__ +to set up the execution environment. To achieve the best performance, +please follow the methods outlined in the `Optimizing Training Platform +Guide `__. + +Requirements +------------ + +- OS: Ubuntu 22.04 LTS +- Python: 3.10 +- Intel Gaudi accelerator +- Intel Gaudi software version 1.18.0 + + +Quick start using Dockerfile +---------------------------- +.. code:: console + + $ docker build -f Dockerfile.hpu -t vllm-hpu-env . + $ docker run -it --runtime=habana -e HABANA_VISIBLE_DEVICES=all -e OMPI_MCA_btl_vader_single_copy_mechanism=none --cap-add=sys_nice --net=host --rm vllm-hpu-env + + +.. tip:: + If you're observing the following error: ``docker: Error response from daemon: Unknown runtime specified habana.``, please refer to "Install Using Containers" section of `Intel Gaudi Software Stack and Driver Installation `__. Make sure you have ``habana-container-runtime`` package installed and that ``habana`` container runtime is registered. + + +Build from source +----------------- + +Environment verification +~~~~~~~~~~~~~~~~~~~~~~~~ + +To verify that the Intel Gaudi software was correctly installed, run: + +.. code:: console + + $ hl-smi # verify that hl-smi is in your PATH and each Gaudi accelerator is visible + $ apt list --installed | grep habana # verify that habanalabs-firmware-tools, habanalabs-graph, habanalabs-rdma-core, habanalabs-thunk and habanalabs-container-runtime are installed + $ pip list | grep habana # verify that habana-torch-plugin, habana-torch-dataloader, habana-pyhlml and habana-media-loader are installed + $ pip list | grep neural # verify that neural_compressor is installed + +Refer to `Intel Gaudi Software Stack +Verification `__ +for more details. + +Run Docker Image +~~~~~~~~~~~~~~~~ + +It is highly recommended to use the latest Docker image from Intel Gaudi +vault. Refer to the `Intel Gaudi +documentation `__ +for more details. + +Use the following commands to run a Docker image: + +.. code:: console + + $ docker pull vault.habana.ai/gaudi-docker/1.18.0/ubuntu22.04/habanalabs/pytorch-installer-2.4.0:latest + $ docker run -it --runtime=habana -e HABANA_VISIBLE_DEVICES=all -e OMPI_MCA_btl_vader_single_copy_mechanism=none --cap-add=sys_nice --net=host --ipc=host vault.habana.ai/gaudi-docker/1.18.0/ubuntu22.04/habanalabs/pytorch-installer-2.4.0:latest + +Build and Install vLLM +~~~~~~~~~~~~~~~~~~~~~~ + +To build and install vLLM from source, run: + +.. code:: console + + $ git clone https://github.com/vllm-project/vllm.git + $ cd vllm + $ python setup.py develop + + +Currently, the latest features and performance optimizations are developed in Gaudi's `vLLM-fork `__ and we periodically upstream them to vLLM main repo. To install latest `HabanaAI/vLLM-fork `__, run the following: + +.. code:: console + + $ git clone https://github.com/HabanaAI/vllm-fork.git + $ cd vllm-fork + $ git checkout habana_main + $ python setup.py develop + + +Supported Features +================== + +- `Offline batched + inference `__ +- Online inference via `OpenAI-Compatible + Server `__ +- HPU autodetection - no need to manually select device within vLLM +- Paged KV cache with algorithms enabled for Intel Gaudi accelerators +- Custom Intel Gaudi implementations of Paged Attention, KV cache ops, + prefill attention, Root Mean Square Layer Normalization, Rotary + Positional Encoding +- Tensor parallelism support for multi-card inference +- Inference with `HPU Graphs `__ + for accelerating low-batch latency and throughput +- Attention with Linear Biases (ALiBi) + +Unsupported Features +==================== + +- Beam search +- LoRA adapters +- Quantization +- Prefill chunking (mixed-batch inferencing) + +Supported Configurations +======================== + +The following configurations have been validated to be function with +Gaudi2 devices. Configurations that are not listed may or may not work. + +- `meta-llama/Llama-2-7b `__ + on single HPU, or with tensor parallelism on 2x and 8x HPU, BF16 + datatype with random or greedy sampling +- `meta-llama/Llama-2-7b-chat-hf `__ + on single HPU, or with tensor parallelism on 2x and 8x HPU, BF16 + datatype with random or greedy sampling +- `meta-llama/Meta-Llama-3-8B `__ + on single HPU, or with tensor parallelism on 2x and 8x HPU, BF16 + datatype with random or greedy sampling +- `meta-llama/Meta-Llama-3-8B-Instruct `__ + on single HPU, or with tensor parallelism on 2x and 8x HPU, BF16 + datatype with random or greedy sampling +- `meta-llama/Meta-Llama-3.1-8B `__ + on single HPU, or with tensor parallelism on 2x and 8x HPU, BF16 + datatype with random or greedy sampling +- `meta-llama/Meta-Llama-3.1-8B-Instruct `__ + on single HPU, or with tensor parallelism on 2x and 8x HPU, BF16 + datatype with random or greedy sampling +- `meta-llama/Llama-2-70b `__ + with tensor parallelism on 8x HPU, BF16 datatype with random or greedy sampling +- `meta-llama/Llama-2-70b-chat-hf `__ + with tensor parallelism on 8x HPU, BF16 datatype with random or greedy sampling +- `meta-llama/Meta-Llama-3-70B `__ + with tensor parallelism on 8x HPU, BF16 datatype with random or greedy sampling +- `meta-llama/Meta-Llama-3-70B-Instruct `__ + with tensor parallelism on 8x HPU, BF16 datatype with random or greedy sampling +- `meta-llama/Meta-Llama-3.1-70B `__ + with tensor parallelism on 8x HPU, BF16 datatype with random or greedy sampling +- `meta-llama/Meta-Llama-3.1-70B-Instruct `__ + with tensor parallelism on 8x HPU, BF16 datatype with random or greedy sampling + +Performance Tuning +================== + +Execution modes +--------------- + +Currently in vLLM for HPU we support four execution modes, depending on selected HPU PyTorch Bridge backend (via ``PT_HPU_LAZY_MODE`` environment variable), and ``--enforce-eager`` flag. + +.. list-table:: vLLM execution modes + :widths: 25 25 50 + :header-rows: 1 + + * - ``PT_HPU_LAZY_MODE`` + - ``enforce_eager`` + - execution mode + * - 0 + - 0 + - torch.compile + * - 0 + - 1 + - PyTorch eager mode + * - 1 + - 0 + - HPU Graphs + * - 1 + - 1 + - PyTorch lazy mode + +.. warning:: + In 1.18.0, all modes utilizing ``PT_HPU_LAZY_MODE=0`` are highly experimental and should be only used for validating functional correctness. Their performance will be improved in the next releases. For obtaining the best performance in 1.18.0, please use HPU Graphs, or PyTorch lazy mode. + + +Bucketing mechanism +------------------- + +Intel Gaudi accelerators work best when operating on models with fixed tensor shapes. `Intel Gaudi Graph Compiler `__ is responsible for generating optimized binary code that implements the given model topology on Gaudi. In its default configuration, the produced binary code may be heavily dependent on input and output tensor shapes, and can require graph recompilation when encountering differently shaped tensors within the same topology. While the resulting binaries utilize Gaudi efficiently, the compilation itself may introduce a noticeable overhead in end-to-end execution. +In a dynamic inference serving scenario, there is a need to minimize the number of graph compilations and reduce the risk of graph compilation occurring during server runtime. Currently it is achieved by "bucketing" model's forward pass across two dimensions - ``batch_size`` and ``sequence_length``. + +.. note:: + Bucketing allows us to reduce the number of required graphs significantly, but it does not handle any graph compilation and device code generation - this is done in warmup and HPUGraph capture phase. + +Bucketing ranges are determined with 3 parameters - ``min``, ``step`` and ``max``. They can be set separately for prompt and decode phase, and for batch size and sequence length dimension. These parameters can be observed in logs during vLLM startup: + +.. code-block:: + + INFO 08-01 21:37:59 hpu_model_runner.py:493] Prompt bucket config (min, step, max_warmup) bs:[1, 32, 4], seq:[128, 128, 1024] + INFO 08-01 21:37:59 hpu_model_runner.py:499] Generated 24 prompt buckets: [(1, 128), (1, 256), (1, 384), (1, 512), (1, 640), (1, 768), (1, 896), (1, 1024), (2, 128), (2, 256), (2, 384), (2, 512), (2, 640), (2, 768), (2, 896), (2, 1024), (4, 128), (4, 256), (4, 384), (4, 512), (4, 640), (4, 768), (4, 896), (4, 1024)] + INFO 08-01 21:37:59 hpu_model_runner.py:504] Decode bucket config (min, step, max_warmup) bs:[1, 128, 4], seq:[128, 128, 2048] + INFO 08-01 21:37:59 hpu_model_runner.py:509] Generated 48 decode buckets: [(1, 128), (1, 256), (1, 384), (1, 512), (1, 640), (1, 768), (1, 896), (1, 1024), (1, 1152), (1, 1280), (1, 1408), (1, 1536), (1, 1664), (1, 1792), (1, 1920), (1, 2048), (2, 128), (2, 256), (2, 384), (2, 512), (2, 640), (2, 768), (2, 896), (2, 1024), (2, 1152), (2, 1280), (2, 1408), (2, 1536), (2, 1664), (2, 1792), (2, 1920), (2, 2048), (4, 128), (4, 256), (4, 384), (4, 512), (4, 640), (4, 768), (4, 896), (4, 1024), (4, 1152), (4, 1280), (4, 1408), (4, 1536), (4, 1664), (4, 1792), (4, 1920), (4, 2048)] + +``min`` determines the lowest value of the bucket. ``step`` determines the interval between buckets, and ``max`` determines the upper bound of the bucket. Furthermore, interval between ``min`` and ``step`` has special handling - ``min`` gets multiplied by consecutive powers of two, until ``step`` gets reached. We call this the ramp-up phase and it is used for handling lower batch sizes with minimum wastage, while allowing larger padding on larger batch sizes. + +Example (with ramp-up) + +.. code-block:: + + min = 2, step = 32, max = 64 + => ramp_up = (2, 4, 8, 16) + => stable = (32, 64) + => buckets = ramp_up + stable => (2, 4, 8, 16, 32, 64) + +Example (without ramp-up) + +.. code-block:: + + min = 128, step = 128, max = 512 + => ramp_up = () + => stable = (128, 256, 384, 512) + => buckets = ramp_up + stable => (128, 256, 384, 512) + + +In the logged scenario, 24 buckets were generated for prompt (prefill) runs, and 48 buckets for decode runs. Each bucket corresponds to a separate optimized device binary for a given model with specified tensor shapes. Whenever a batch of requests is processed, it is padded across batch and sequence length dimension to the smallest possible bucket. + +.. warning:: + If a request exceeds maximum bucket size in any dimension, it will be processed without padding, and its processing may require a graph compilation, potentially significantly increasing end-to-end latency. The boundaries of the buckets are user-configurable via environment variables, and upper bucket boundaries can be increased to avoid such scenario. + +As an example, if a request of 3 sequences, with max sequence length of 412 comes in to an idle vLLM server, it will be padded executed as ``(4, 512)`` prefill bucket, as ``batch_size`` (number of sequences) will be padded to 4 (closest batch_size dimension higher than 3), and max sequence length will be padded to 512 (closest sequence length dimension higher than 412). After prefill stage, it will be executed as ``(4, 512)`` decode bucket and will continue as that bucket until either batch dimension changes (due to request being finished) - in which case it will become a ``(2, 512)`` bucket, or context length increases above 512 tokens, in which case it will become ``(4, 640)`` bucket. + +.. note:: + Bucketing is transparent to a client - padding in sequence length dimension is never returned to the client, and padding in batch dimension does not create new requests. + +Warmup +------ + +Warmup is an optional, but highly recommended step occurring before vLLM server starts listening. It executes a forward pass for each bucket with dummy data. The goal is to pre-compile all graphs and not incur any graph compilation overheads within bucket boundaries during server runtime. Each warmup step is logged during vLLM startup: + +.. code-block:: + + INFO 08-01 22:26:47 hpu_model_runner.py:1066] [Warmup][Prompt][1/24] batch_size:4 seq_len:1024 free_mem:79.16 GiB + INFO 08-01 22:26:47 hpu_model_runner.py:1066] [Warmup][Prompt][2/24] batch_size:4 seq_len:896 free_mem:55.43 GiB + INFO 08-01 22:26:48 hpu_model_runner.py:1066] [Warmup][Prompt][3/24] batch_size:4 seq_len:768 free_mem:55.43 GiB + ... + INFO 08-01 22:26:59 hpu_model_runner.py:1066] [Warmup][Prompt][24/24] batch_size:1 seq_len:128 free_mem:55.43 GiB + INFO 08-01 22:27:00 hpu_model_runner.py:1066] [Warmup][Decode][1/48] batch_size:4 seq_len:2048 free_mem:55.43 GiB + INFO 08-01 22:27:00 hpu_model_runner.py:1066] [Warmup][Decode][2/48] batch_size:4 seq_len:1920 free_mem:55.43 GiB + INFO 08-01 22:27:01 hpu_model_runner.py:1066] [Warmup][Decode][3/48] batch_size:4 seq_len:1792 free_mem:55.43 GiB + ... + INFO 08-01 22:27:16 hpu_model_runner.py:1066] [Warmup][Decode][47/48] batch_size:2 seq_len:128 free_mem:55.43 GiB + INFO 08-01 22:27:16 hpu_model_runner.py:1066] [Warmup][Decode][48/48] batch_size:1 seq_len:128 free_mem:55.43 GiB + +This example uses the same buckets as in *Bucketing mechanism* section. Each output line corresponds to execution of a single bucket. When bucket is executed for the first time, its graph is compiled and can be reused later on, skipping further graph compilations. + +.. tip:: + Compiling all the buckets might take some time and can be turned off with ``VLLM_SKIP_WARMUP=true`` environment variable. Keep in mind that if you do that, you may face graph compilations once executing a given bucket for the first time. It is fine to disable warmup for development, but it's highly recommended to enable it in deployment. + +HPU Graph capture +----------------- + +`HPU Graphs `__ are currently the most performant execution method of vLLM on Intel Gaudi. When HPU Graphs are enabled, execution graphs will be traced (recorded) ahead of time (after performing warmup), to be later replayed during inference, significantly reducing host overheads. Recording can take large amounts of memory, which needs to be taken into account when allocating KV cache. Enabling HPU Graphs will impact the number of available KV cache blocks, but vLLM provides user-configurable variables to control memory management. + + +When HPU Graphs are being used, they share the common memory pool ("usable memory") as KV cache, determined by ``gpu_memory_utilization`` flag (``0.9`` by default). +Before KV cache gets allocated, model weights are loaded onto the device, and a forward pass of the model is executed on dummy data, to estimate memory usage. +Only after that, ``gpu_memory_utilization`` flag is utilized - at its default value, will mark 90% of free device memory at that point as usable. +Next, KV cache gets allocated, model is warmed up, and HPU Graphs are captured. +Environment variable ``VLLM_GRAPH_RESERVED_MEM`` defines the ratio of memory reserved for HPU Graphs capture. +With its default value (``VLLM_GRAPH_RESERVED_MEM=0.1``), 10% of usable memory will be reserved for graph capture (later referred to as "usable graph memory"), and the remaining 90% will be utilized for KV cache. +Environment variable ``VLLM_GRAPH_PROMPT_RATIO`` determines the ratio of usable graph memory reserved for prefill and decode graphs. By default (``VLLM_GRAPH_PROMPT_RATIO=0.3``), both stages have equal memory constraints. +Lower value corresponds to less usable graph memory reserved for prefill stage, e.g. ``VLLM_GRAPH_PROMPT_RATIO=0.2`` will reserve 20% of usable graph memory for prefill graphs, and 80% of usable graph memory for decode graphs. + +.. note:: + ``gpu_memory_utilization`` does not correspond to the absolute memory usage across HPU. It specifies the memory margin after loading the model and performing a profile run. If device has 100 GiB of total memory, and 50 GiB of free memory after loading model weights and executing profiling run, ``gpu_memory_utilization`` at its default value will mark 90% of 50 GiB as usable, leaving 5 GiB of margin, regardless of total device memory. + +User can also configure the strategy for capturing HPU Graphs for prompt and decode stages separately. Strategy affects the order of capturing graphs. There are two strategies implemented: +- ``max_bs`` - graph capture queue will sorted in descending order by their batch sizes. Buckets with equal batch sizes are sorted by sequence length in ascending order (e.g. ``(64, 128)``, ``(64, 256)``, ``(32, 128)``, ``(32, 256)``, ``(1, 128)``, ``(1,256)``), default strategy for decode +- ``min_tokens`` - graph capture queue will be sorted in ascending order by the number of tokens each graph processes (``batch_size*sequence_length``), default strategy for prompt + +When there's large amount of requests pending, vLLM scheduler will attempt to fill the maximum batch size for decode as soon as possible. When a request is finished, decode batch size decreases. When that happens, vLLM will attempt to schedule a prefill iteration for requests in the waiting queue, to fill the decode batch size to its previous state. This means that in a full load scenario, decode batch size is often at its maximum, which makes large batch size HPU Graphs crucial to capture, as reflected by ``max_bs`` strategy. On the other hand, prefills will be executed most frequently with very low batch sizes (1-4), which is reflected in ``min_tokens`` strategy. + + +.. note:: + ``VLLM_GRAPH_PROMPT_RATIO`` does not set a hard limit on memory taken by graphs for each stage (prefill and decode). vLLM will first attempt to use up entirety of usable prefill graph memory (usable graph memory * ``VLLM_GRAPH_PROMPT_RATIO``) for capturing prefill HPU Graphs, next it will attempt do the same for decode graphs and usable decode graph memory pool. If one stage is fully captured, and there is unused memory left within usable graph memory pool, vLLM will attempt further graph capture for the other stage, until no more HPU Graphs can be captured without exceeding reserved memory pool. The behavior on that mechanism can be observed in the example below. + + +Each described step is logged by vLLM server, as follows (negative values correspond to memory being released): + +.. code-block:: + + INFO 08-02 17:37:44 hpu_model_runner.py:493] Prompt bucket config (min, step, max_warmup) bs:[1, 32, 4], seq:[128, 128, 1024] + INFO 08-02 17:37:44 hpu_model_runner.py:499] Generated 24 prompt buckets: [(1, 128), (1, 256), (1, 384), (1, 512), (1, 640), (1, 768), (1, 896), (1, 1024), (2, 128), (2, 256), (2, 384), (2, 512), (2, 640), (2, 768), (2, 896), (2, 1024), (4, 128), (4, 256), (4, 384), (4, 512), (4, 640), (4, 768), (4, 896), (4, 1024)] + INFO 08-02 17:37:44 hpu_model_runner.py:504] Decode bucket config (min, step, max_warmup) bs:[1, 128, 4], seq:[128, 128, 2048] + INFO 08-02 17:37:44 hpu_model_runner.py:509] Generated 48 decode buckets: [(1, 128), (1, 256), (1, 384), (1, 512), (1, 640), (1, 768), (1, 896), (1, 1024), (1, 1152), (1, 1280), (1, 1408), (1, 1536), (1, 1664), (1, 1792), (1, 1920), (1, 2048), (2, 128), (2, 256), (2, 384), (2, 512), (2, 640), (2, 768), (2, 896), (2, 1024), (2, 1152), (2, 1280), (2, 1408), (2, 1536), (2, 1664), (2, 1792), (2, 1920), (2, 2048), (4, 128), (4, 256), (4, 384), (4, 512), (4, 640), (4, 768), (4, 896), (4, 1024), (4, 1152), (4, 1280), (4, 1408), (4, 1536), (4, 1664), (4, 1792), (4, 1920), (4, 2048)] + INFO 08-02 17:37:52 hpu_model_runner.py:430] Pre-loading model weights on hpu:0 took 14.97 GiB of device memory (14.97 GiB/94.62 GiB used) and 2.95 GiB of host memory (475.2 GiB/1007 GiB used) + INFO 08-02 17:37:52 hpu_model_runner.py:438] Wrapping in HPU Graph took 0 B of device memory (14.97 GiB/94.62 GiB used) and -252 KiB of host memory (475.2 GiB/1007 GiB used) + INFO 08-02 17:37:52 hpu_model_runner.py:442] Loading model weights took in total 14.97 GiB of device memory (14.97 GiB/94.62 GiB used) and 2.95 GiB of host memory (475.2 GiB/1007 GiB used) + INFO 08-02 17:37:54 hpu_worker.py:134] Model profiling run took 504 MiB of device memory (15.46 GiB/94.62 GiB used) and 180.9 MiB of host memory (475.4 GiB/1007 GiB used) + INFO 08-02 17:37:54 hpu_worker.py:158] Free device memory: 79.16 GiB, 39.58 GiB usable (gpu_memory_utilization=0.5), 15.83 GiB reserved for HPUGraphs (VLLM_GRAPH_RESERVED_MEM=0.4), 23.75 GiB reserved for KV cache + INFO 08-02 17:37:54 hpu_executor.py:85] # HPU blocks: 1519, # CPU blocks: 0 + INFO 08-02 17:37:54 hpu_worker.py:190] Initializing cache engine took 23.73 GiB of device memory (39.2 GiB/94.62 GiB used) and -1.238 MiB of host memory (475.4 GiB/1007 GiB used) + INFO 08-02 17:37:54 hpu_model_runner.py:1066] [Warmup][Prompt][1/24] batch_size:4 seq_len:1024 free_mem:55.43 GiB + ... + INFO 08-02 17:38:22 hpu_model_runner.py:1066] [Warmup][Decode][48/48] batch_size:1 seq_len:128 free_mem:55.43 GiB + INFO 08-02 17:38:22 hpu_model_runner.py:1159] Using 15.85 GiB/55.43 GiB of free device memory for HPUGraphs, 7.923 GiB for prompt and 7.923 GiB for decode (VLLM_GRAPH_PROMPT_RATIO=0.3) + INFO 08-02 17:38:22 hpu_model_runner.py:1066] [Warmup][Graph/Prompt][1/24] batch_size:1 seq_len:128 free_mem:55.43 GiB + ... + INFO 08-02 17:38:26 hpu_model_runner.py:1066] [Warmup][Graph/Prompt][11/24] batch_size:1 seq_len:896 free_mem:48.77 GiB + INFO 08-02 17:38:27 hpu_model_runner.py:1066] [Warmup][Graph/Decode][1/48] batch_size:4 seq_len:128 free_mem:47.51 GiB + ... + INFO 08-02 17:38:41 hpu_model_runner.py:1066] [Warmup][Graph/Decode][48/48] batch_size:1 seq_len:2048 free_mem:47.35 GiB + INFO 08-02 17:38:41 hpu_model_runner.py:1066] [Warmup][Graph/Prompt][12/24] batch_size:4 seq_len:256 free_mem:47.35 GiB + INFO 08-02 17:38:42 hpu_model_runner.py:1066] [Warmup][Graph/Prompt][13/24] batch_size:2 seq_len:512 free_mem:45.91 GiB + INFO 08-02 17:38:42 hpu_model_runner.py:1066] [Warmup][Graph/Prompt][14/24] batch_size:1 seq_len:1024 free_mem:44.48 GiB + INFO 08-02 17:38:43 hpu_model_runner.py:1066] [Warmup][Graph/Prompt][15/24] batch_size:2 seq_len:640 free_mem:43.03 GiB + INFO 08-02 17:38:43 hpu_model_runner.py:1128] Graph/Prompt captured:15 (62.5%) used_mem:14.03 GiB buckets:[(1, 128), (1, 256), (1, 384), (1, 512), (1, 640), (1, 768), (1, 896), (1, 1024), (2, 128), (2, 256), (2, 384), (2, 512), (2, 640), (4, 128), (4, 256)] + INFO 08-02 17:38:43 hpu_model_runner.py:1128] Graph/Decode captured:48 (100.0%) used_mem:161.9 MiB buckets:[(1, 128), (1, 256), (1, 384), (1, 512), (1, 640), (1, 768), (1, 896), (1, 1024), (1, 1152), (1, 1280), (1, 1408), (1, 1536), (1, 1664), (1, 1792), (1, 1920), (1, 2048), (2, 128), (2, 256), (2, 384), (2, 512), (2, 640), (2, 768), (2, 896), (2, 1024), (2, 1152), (2, 1280), (2, 1408), (2, 1536), (2, 1664), (2, 1792), (2, 1920), (2, 2048), (4, 128), (4, 256), (4, 384), (4, 512), (4, 640), (4, 768), (4, 896), (4, 1024), (4, 1152), (4, 1280), (4, 1408), (4, 1536), (4, 1664), (4, 1792), (4, 1920), (4, 2048)] + INFO 08-02 17:38:43 hpu_model_runner.py:1206] Warmup finished in 49 secs, allocated 14.19 GiB of device memory + INFO 08-02 17:38:43 hpu_executor.py:91] init_cache_engine took 37.92 GiB of device memory (53.39 GiB/94.62 GiB used) and 57.86 MiB of host memory (475.4 GiB/1007 GiB used) + + +Recommended vLLM Parameters +--------------------------- + +- We recommend running inference on Gaudi 2 with ``block_size`` of 128 + for BF16 data type. Using default values (16, 32) might lead to + sub-optimal performance due to Matrix Multiplication Engine + under-utilization (see `Gaudi + Architecture `__). +- For max throughput on Llama 7B, we recommend running with batch size + of 128 or 256 and max context length of 2048 with HPU Graphs enabled. + If you encounter out-of-memory issues, see troubleshooting section. + +Environment variables +--------------------- + +**Diagnostic and profiling knobs:** + +- ``VLLM_PROFILER_ENABLED``: if ``true``, high level profiler will be enabled. Resulting JSON traces can be viewed in `perfetto.habana.ai `__. Disabled by default. +- ``VLLM_HPU_LOG_STEP_GRAPH_COMPILATION``: if ``true``, will log graph compilations per each vLLM engine step, only when there was any - highly recommended to use alongside ``PT_HPU_METRICS_GC_DETAILS=1``. Disabled by default. +- ``VLLM_HPU_LOG_STEP_GRAPH_COMPILATION_ALL``: if ``true``, will log graph compilations per each vLLM engine step, always, even if there were none. Disabled by default. +- ``VLLM_HPU_LOG_STEP_CPU_FALLBACKS``: if ``true``, will log cpu fallbacks per each vLLM engine step, only when there was any. Disabled by default. +- ``VLLM_HPU_LOG_STEP_CPU_FALLBACKS_ALL``: if ``true``, will log cpu fallbacks per each vLLM engine step, always, even if there were none. Disabled by default. + +**Performance tuning knobs:** + +- ``VLLM_SKIP_WARMUP``: if ``true``, warmup will be skipped, ``false`` by default +- ``VLLM_GRAPH_RESERVED_MEM``: percentage of memory dedicated for HPUGraph capture, ``0.1`` by default +- ``VLLM_GRAPH_PROMPT_RATIO``: percentage of reserved graph memory dedicated for prompt graphs, ``0.3`` by default +- ``VLLM_GRAPH_PROMPT_STRATEGY``: strategy determining order of prompt graph capture, ``min_tokens`` or ``max_bs``, ``min_tokens`` by default +- ``VLLM_GRAPH_DECODE_STRATEGY``: strategy determining order of decode graph capture, ``min_tokens`` or ``max_bs``, ``max_bs`` by default +- ``VLLM_{phase}_{dim}_BUCKET_{param}`` - collection of 12 environment variables configuring ranges of bucketing mechanism + + - ``{phase}`` is either ``PROMPT`` or ``DECODE`` + - ``{dim}`` is either ``BS``, ``SEQ`` or ``BLOCK`` + - ``{param}`` is either ``MIN``, ``STEP`` or ``MAX`` + - Default values: + + - Prompt: + - batch size min (``VLLM_PROMPT_BS_BUCKET_MIN``): ``1`` + - batch size step (``VLLM_PROMPT_BS_BUCKET_STEP``): ``min(max_num_seqs, 32)`` + - batch size max (``VLLM_PROMPT_BS_BUCKET_MAX``): ``min(max_num_seqs, 64)`` + - sequence length min (``VLLM_PROMPT_SEQ_BUCKET_MIN``): ``block_size`` + - sequence length step (``VLLM_PROMPT_SEQ_BUCKET_STEP``): ``block_size`` + - sequence length max (``VLLM_PROMPT_SEQ_BUCKET_MAX``): ``max_model_len`` + + - Decode: + - batch size min (``VLLM_DECODE_BS_BUCKET_MIN``): ``1`` + - batch size step (``VLLM_DECODE_BS_BUCKET_STEP``): ``min(max_num_seqs, 32)`` + - batch size max (``VLLM_DECODE_BS_BUCKET_MAX``): ``max_num_seqs`` + - sequence length min (``VLLM_DECODE_BLOCK_BUCKET_MIN``): ``block_size`` + - sequence length step (``VLLM_DECODE_BLOCK_BUCKET_STEP``): ``block_size`` + - sequence length max (``VLLM_DECODE_BLOCK_BUCKET_MAX``): ``max(128, (max_num_seqs*max_model_len)/block_size)`` + + +Additionally, there are HPU PyTorch Bridge environment variables impacting vLLM execution: + +- ``PT_HPU_LAZY_MODE``: if ``0``, PyTorch Eager backend for Gaudi will be used, if ``1`` PyTorch Lazy backend for Gaudi will be used, ``1`` is default +- ``PT_HPU_ENABLE_LAZY_COLLECTIVES``: required to be ``true`` for tensor parallel inference with HPU Graphs + +Troubleshooting: Tweaking HPU Graphs +==================================== + +If you experience device out-of-memory issues or want to attempt +inference at higher batch sizes, try tweaking HPU Graphs by following +the below: + +- Tweak ``gpu_memory_utilization`` knob. It will decrease the + allocation of KV cache, leaving some headroom for capturing graphs + with larger batch size. By default ``gpu_memory_utilization`` is set + to 0.9. It attempts to allocate ~90% of HBM left for KV cache after + short profiling run. Note that decreasing reduces the number of KV + cache blocks you have available, and therefore reduces the effective + maximum number of tokens you can handle at a given time. + +- If this method is not efficient, you can disable ``HPUGraph`` + completely. With HPU Graphs disabled, you are trading latency and + throughput at lower batches for potentially higher throughput on + higher batches. You can do that by adding ``--enforce-eager`` flag to + server (for online inference), or by passing ``enforce_eager=True`` + argument to LLM constructor (for offline inference). diff --git a/docs/source/index.rst b/docs/source/index.rst index d20e46b4a3656..ac95e660e1740 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -43,7 +43,7 @@ vLLM is flexible and easy to use with: * Tensor parallelism and pipeline parallelism support for distributed inference * Streaming outputs * OpenAI-compatible API server -* Support NVIDIA GPUs, AMD CPUs and GPUs, Intel CPUs and GPUs, PowerPC CPUs, TPU, and AWS Trainium and Inferentia Accelerators. +* Support NVIDIA GPUs, AMD CPUs and GPUs, Intel CPUs and Gaudi® accelerators, GPUs, PowerPC CPUs, TPU, and AWS Trainium and Inferentia Accelerators. * Prefix caching support * Multi-lora support @@ -66,6 +66,7 @@ Documentation getting_started/amd-installation getting_started/openvino-installation getting_started/cpu-installation + getting_started/gaudi-installation getting_started/neuron-installation getting_started/tpu-installation getting_started/xpu-installation diff --git a/requirements-hpu.txt b/requirements-hpu.txt new file mode 100644 index 0000000000000..1a583974be151 --- /dev/null +++ b/requirements-hpu.txt @@ -0,0 +1,11 @@ +# Common dependencies +-r requirements-common.txt + +# Dependencies for HPU code +ray == 2.32.0 +triton +pandas +tabulate +setuptools>=61 +setuptools-scm>=8 +vllm-hpu-extension @ git+https://github.com/HabanaAI/vllm-hpu-extension.git@fd7f2e6 diff --git a/setup.py b/setup.py index 9ea4e85c07542..bd114d4b08c31 100644 --- a/setup.py +++ b/setup.py @@ -246,6 +246,24 @@ def run(self): self.copy_file(file, dst_file) +def _is_hpu() -> bool: + is_hpu_available = True + try: + subprocess.run(["hl-smi"], capture_output=True, check=True) + except (FileNotFoundError, PermissionError, subprocess.CalledProcessError): + if not os.path.exists('/dev/accel/accel0') and not os.path.exists( + '/dev/accel/accel_controlD0'): + # last resort... + try: + output = subprocess.check_output( + 'lsmod | grep habanalabs | wc -l', shell=True) + is_hpu_available = int(output) > 0 + except (ValueError, FileNotFoundError, PermissionError, + subprocess.CalledProcessError): + is_hpu_available = False + return is_hpu_available or VLLM_TARGET_DEVICE == "hpu" + + def _no_device() -> bool: return VLLM_TARGET_DEVICE == "empty" @@ -253,7 +271,7 @@ def _no_device() -> bool: def _is_cuda() -> bool: has_cuda = torch.version.cuda is not None return (VLLM_TARGET_DEVICE == "cuda" and has_cuda - and not (_is_neuron() or _is_tpu())) + and not (_is_neuron() or _is_tpu() or _is_hpu())) def _is_hip() -> bool: @@ -291,7 +309,8 @@ def _build_custom_ops() -> bool: def _build_core_ext() -> bool: - return not (_is_neuron() or _is_tpu() or _is_openvino() or _is_xpu()) + return not (_is_neuron() or _is_tpu() or _is_openvino() or _is_xpu() + or _is_hpu()) def get_hipcc_rocm_version(): @@ -353,6 +372,23 @@ def get_path(*filepath) -> str: return os.path.join(ROOT_DIR, *filepath) +def get_gaudi_sw_version(): + """ + Returns the driver version. + """ + # Enable console printing for `hl-smi` check + output = subprocess.run("hl-smi", + shell=True, + text=True, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + env={"ENABLE_CONSOLE": "true"}) + if output.returncode == 0 and output.stdout: + return output.stdout.split("\n")[2].replace( + " ", "").split(":")[1][:-1].split("-")[0] + return "0.0.0" # when hl-smi is not available + + def get_vllm_version() -> str: version = get_version( write_to="vllm/_version.py", # TODO: move this to pyproject.toml @@ -382,6 +418,12 @@ def get_vllm_version() -> str: if neuron_version != MAIN_CUDA_VERSION: neuron_version_str = neuron_version.replace(".", "")[:3] version += f"{sep}neuron{neuron_version_str}" + elif _is_hpu(): + # Get the Intel Gaudi Software Suite version + gaudi_sw_version = str(get_gaudi_sw_version()) + if gaudi_sw_version != MAIN_CUDA_VERSION: + gaudi_sw_version = gaudi_sw_version.replace(".", "")[:3] + version += f"{sep}gaudi{gaudi_sw_version}" elif _is_openvino(): version += f"{sep}openvino" elif _is_tpu(): @@ -439,6 +481,8 @@ def _read_requirements(filename: str) -> List[str]: requirements = _read_requirements("requirements-rocm.txt") elif _is_neuron(): requirements = _read_requirements("requirements-neuron.txt") + elif _is_hpu(): + requirements = _read_requirements("requirements-hpu.txt") elif _is_openvino(): requirements = _read_requirements("requirements-openvino.txt") elif _is_tpu(): @@ -449,7 +493,7 @@ def _read_requirements(filename: str) -> List[str]: requirements = _read_requirements("requirements-xpu.txt") else: raise ValueError( - "Unsupported platform, please use CUDA, ROCm, Neuron, " + "Unsupported platform, please use CUDA, ROCm, Neuron, HPU, " "OpenVINO, or CPU.") return requirements diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index ec035f137c3a6..1b3b4c0979522 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -12,7 +12,7 @@ logger = init_logger(__name__) -if not current_platform.is_tpu(): +if not current_platform.is_tpu() and not current_platform.is_hpu(): try: import vllm._C except ImportError as e: diff --git a/vllm/attention/backends/hpu_attn.py b/vllm/attention/backends/hpu_attn.py new file mode 100644 index 0000000000000..a8f4b09b67274 --- /dev/null +++ b/vllm/attention/backends/hpu_attn.py @@ -0,0 +1,264 @@ +############################################################################### +# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company +############################################################################### + +import os +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Type + +import torch +import vllm_hpu_extension.ops as ops +from vllm_hpu_extension.utils import Matmul, Softmax, VLLMKVCache + +from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, + AttentionMetadata, AttentionType) +from vllm.attention.backends.utils import CommonAttentionState +from vllm.attention.ops.hpu_paged_attn import (HPUPagedAttention, + HPUPagedAttentionMetadata) +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +class HPUAttentionBackend(AttentionBackend): + + @staticmethod + def get_impl_cls() -> Type["HPUAttentionImpl"]: + return HPUAttentionImpl + + @staticmethod + def get_metadata_cls() -> Type["AttentionMetadata"]: + return HPUAttentionMetadata + + @staticmethod + def get_state_cls() -> Type["CommonAttentionState"]: + return CommonAttentionState + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + ) -> Tuple[int, ...]: + return HPUPagedAttention.get_kv_cache_shape(num_blocks, block_size, + num_kv_heads, head_size) + + @staticmethod + def swap_blocks( + src_kv_cache: torch.Tensor, + dst_kv_cache: torch.Tensor, + src_to_dst: Dict[int, int], + ) -> None: + HPUPagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst) + + @staticmethod + def copy_blocks( + kv_caches: List[torch.Tensor], + src_to_dists: Dict[int, List[int]], + ) -> None: + HPUPagedAttention.copy_blocks(kv_caches, src_to_dists) + + +@dataclass +class HPUAttentionMetadata(HPUPagedAttentionMetadata, AttentionMetadata): + """Metadata for HPUAttentionbackend.""" + # Currently, input sequences can only contain all prompts + # or all decoding. True if all sequences are prompts. + is_prompt: bool + attn_bias: Optional[torch.Tensor] + seq_lens_tensor: Optional[torch.Tensor] + + +class HPUAttentionImpl(AttentionImpl, torch.nn.Module): + """ + If the input tensors contain prompt tokens, the layout is as follows: + |<--------------- num_prefill_tokens ----------------->| + |<--prefill_0-->|<--prefill_1-->|...|<--prefill_N-1--->| + + Otherwise, the layout is as follows: + |<----------------- num_decode_tokens ------------------>| + |<--decode_0-->|..........|<--decode_M-1-->|<--padding-->| + + Generation tokens can contain padding when cuda-graph is used. + Currently, prompt tokens don't contain any padding. + + The prompts might have different lengths, while the generation tokens + always have length 1. + """ + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[List[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + blocksparse_params: Optional[Dict[str, Any]] = None, + max_seq_len: int = 4096, + ) -> None: + super(AttentionImpl, self).__init__() + self.kv_cache_dtype = kv_cache_dtype + self.num_heads = num_heads + self.head_size = head_size + self.scale = float(scale) + self.matmul_qk = Matmul() + self.softmax = Softmax() + self.matmul_av = Matmul() + self.k_cache = VLLMKVCache() + self.v_cache = VLLMKVCache() + self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads + self.sliding_window = sliding_window + self.alibi_slopes = alibi_slopes + if alibi_slopes is not None: + alibi_slopes_tensor = torch.tensor(alibi_slopes, + dtype=torch.bfloat16) + self.alibi_slopes = alibi_slopes_tensor + assert self.num_heads % self.num_kv_heads == 0 + self.num_queries_per_kv = self.num_heads // self.num_kv_heads + + self.prefill_usefusedsdpa = os.getenv('VLLM_PROMPT_USE_FUSEDSDPA', + '0').lower() in ['1', 'true'] + if self.prefill_usefusedsdpa: + assert alibi_slopes is None, \ + 'Prefill with FusedSDPA not supported with alibi slopes!' + + suppored_head_sizes = HPUPagedAttention.get_supported_head_sizes() + if head_size not in suppored_head_sizes: + raise ValueError( + f"Head size {head_size} is not supported by PagedAttention. " + f"Supported head sizes are: {suppored_head_sizes}.") + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: HPUAttentionMetadata, + k_scale: float = 1.0, + v_scale: float = 1.0, + attn_type: AttentionType = AttentionType.DECODER, + ) -> torch.Tensor: + """Forward pass with xFormers and PagedAttention. + + Args: + query: shape = [num_tokens, num_heads * head_size] + key: shape = [num_tokens, num_kv_heads * head_size] + value: shape = [num_tokens, num_kv_heads * head_size] + kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size] + attn_metadata: Metadata for attention. + Returns: + shape = [num_tokens, num_heads * head_size] + """ + if attn_type != AttentionType.DECODER: + raise NotImplementedError("Encoder self-attention and " + "encoder/decoder cross-attention " + "are not implemented for " + "HPUAttentionImpl") + batch_size, seq_len, hidden_size = query.shape + _, seq_len_kv, _ = key.shape + + query = query.view(-1, self.num_heads, self.head_size) + key = key.view(-1, self.num_kv_heads, self.head_size) + value = value.view(-1, self.num_kv_heads, self.head_size) + block_indices = attn_metadata.block_indices + block_offsets = attn_metadata.block_offsets + if attn_metadata.is_prompt: + key = key.unflatten(0, (block_indices.size(0), -1)) + value = value.unflatten(0, (block_indices.size(0), -1)) + if kv_cache is not None: + key_cache, value_cache = HPUPagedAttention.split_kv_cache( + kv_cache, self.num_kv_heads, self.head_size) + + # Reshape the input keys and values and store them in the cache. + # If kv_cache is not provided, the new key and value tensors are + # not cached. This happens during the initial memory profiling run. + key_cache = self.k_cache(key, key_cache, block_indices, + block_offsets) + value_cache = self.v_cache(value, value_cache, block_indices, + block_offsets) + + if attn_metadata.is_prompt: + # Prompt run. + if not self.prefill_usefusedsdpa: + # TODO: move this outside of model + assert attn_metadata.attn_bias is not None, \ + 'attn_bias must be set before calling model.forward!' + attn_bias = attn_metadata.attn_bias + if self.alibi_slopes is not None: + position_bias = _make_alibi_bias(self.alibi_slopes, + self.num_kv_heads, + attn_bias.dtype, + attn_bias.shape[-1]) + attn_bias = attn_bias.tile((1, self.num_kv_heads, 1, 1)) + attn_bias.add_(position_bias) + else: + attn_bias = None + + query_shape = (batch_size, seq_len, self.num_heads, self.head_size) + kv_shape = (batch_size, seq_len_kv, self.num_kv_heads, + self.head_size) + out = ops.prompt_attention( + query.view(query_shape), + key.view(kv_shape), + value.view(kv_shape), + attn_bias=attn_bias, + p=0.0, + scale=self.scale, + matmul_qk_op=self.matmul_qk, + softmax_op=self.softmax, + matmul_av_op=self.matmul_av, + ) + output = out.reshape(batch_size, seq_len, hidden_size) + else: + # Decoding run. + output = HPUPagedAttention.forward_decode( + query=query, + key_cache=key_cache, + value_cache=value_cache, + block_list=attn_metadata.block_list, + block_mapping=attn_metadata.block_mapping, + block_bias=attn_metadata.attn_bias, + block_scales=attn_metadata.block_scales, + scale=self.scale, + matmul_qk_op=self.matmul_qk, + matmul_av_op=self.matmul_av, + keys_fetch_func=self.k_cache.fetch_from_cache, + values_fetch_func=self.v_cache.fetch_from_cache) + # Reshape the output tensor. + return output.view(batch_size, seq_len, hidden_size) + + +def _make_alibi_bias( + alibi_slopes: torch.Tensor, + num_kv_heads: int, + dtype: torch.dtype, + seq_len: int, +) -> torch.Tensor: + bias = torch.arange(seq_len, dtype=dtype) + # NOTE(zhuohan): HF uses + # `bias = bias[None, :].repeat(seq_len, 1)` + # here. We find that both biases give the same results, but + # the bias below more accurately follows the original ALiBi + # paper. + # Calculate a matrix where each element represents ith element- jth + # element. + bias = bias[None, :] - bias[:, None] + + padded_len = (seq_len + 7) // 8 * 8 + num_heads = alibi_slopes.shape[0] + bias = torch.empty( + 1, # batch size + num_heads, + seq_len, + padded_len, + device=alibi_slopes.device, + dtype=dtype, + )[:, :, :, :seq_len].copy_(bias) + bias.mul_(alibi_slopes[:, None, None]) + if num_heads != num_kv_heads: + bias = bias.unflatten(1, (num_kv_heads, num_heads // num_kv_heads)) + return bias diff --git a/vllm/attention/ops/hpu_paged_attn.py b/vllm/attention/ops/hpu_paged_attn.py new file mode 100644 index 0000000000000..4c0fb2a628361 --- /dev/null +++ b/vllm/attention/ops/hpu_paged_attn.py @@ -0,0 +1,103 @@ +############################################################################### +# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company +############################################################################### + +from dataclasses import dataclass +from typing import Dict, List, Optional, Tuple + +import torch +from vllm_hpu_extension import cache_ops, ops + +# Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`. +_PARTITION_SIZE = 512 + + +@dataclass +class HPUPagedAttentionMetadata: + """Metadata for PagedAttention.""" + block_list: Optional[torch.Tensor] + block_mapping: Optional[torch.Tensor] + block_usage: Optional[torch.Tensor] + block_indices: Optional[torch.Tensor] + block_offsets: Optional[torch.Tensor] + block_scales: Optional[torch.Tensor] + + +class HPUPagedAttention: + + @staticmethod + def get_supported_head_sizes() -> List[int]: + return [64, 80, 96, 112, 128, 256] + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + ) -> Tuple[int, ...]: + return (num_blocks, block_size, num_kv_heads, head_size) + + @staticmethod + def split_kv_cache( + kv_cache: torch.Tensor, + num_kv_heads: int, + head_size: int, + ) -> Tuple[torch.Tensor, torch.Tensor]: + key_cache = kv_cache[0] + value_cache = kv_cache[1] + return key_cache, value_cache + + @staticmethod + def write_to_paged_cache(key: torch.Tensor, value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + slot_mapping: torch.Tensor, kv_cache_dtype: str, + is_prompt: bool) -> None: + cache_ops.reshape_and_cache(key, value, key_cache, value_cache, + slot_mapping, kv_cache_dtype, is_prompt) + + @staticmethod + def forward_decode(**kwargs) -> torch.Tensor: + return ops.flat_pa(**kwargs) + + @staticmethod + def forward_prefix( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + block_tables: torch.Tensor, + subquery_start_loc: torch.Tensor, + seq_lens_tensor: torch.Tensor, + context_lens: torch.Tensor, + max_query_len: int, + alibi_slopes: Optional[torch.Tensor], + sliding_window: Optional[int], + ) -> torch.Tensor: + raise NotImplementedError( + "forward_prefix is not implemented for HPUPagedAttention") + + @staticmethod + def swap_blocks( + src_kv_cache: torch.Tensor, + dst_kv_cache: torch.Tensor, + src_to_dst: Dict[int, int], + ) -> None: + src_key_cache = src_kv_cache[0] + dst_key_cache = dst_kv_cache[0] + cache_ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dst) + + src_value_cache = src_kv_cache[1] + dst_value_cache = dst_kv_cache[1] + cache_ops.swap_blocks(src_value_cache, dst_value_cache, src_to_dst) + + @staticmethod + def copy_blocks( + kv_caches: List[torch.Tensor], + src_to_dists: Dict[int, List[int]], + ) -> None: + key_caches = [kv_cache[0] for kv_cache in kv_caches] + value_caches = [kv_cache[1] for kv_cache in kv_caches] + cache_ops.copy_blocks(key_caches, value_caches, src_to_dists) diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py index 7edb7676ea2cd..e88f35f808fc8 100644 --- a/vllm/attention/selector.py +++ b/vllm/attention/selector.py @@ -22,6 +22,7 @@ class _Backend(enum.Enum): TORCH_SDPA = enum.auto() OPENVINO = enum.auto() FLASHINFER = enum.auto() + HPU_ATTN = enum.auto() PALLAS = enum.auto() IPEX = enum.auto() NO_ATTENTION = enum.auto() @@ -141,6 +142,10 @@ def get_attn_backend( logger.info("Using Flashinfer backend.") from vllm.attention.backends.flashinfer import FlashInferBackend return FlashInferBackend + elif backend == _Backend.HPU_ATTN: + logger.info("Using HPUAttention backend.") + from vllm.attention.backends.hpu_attn import HPUAttentionBackend + return HPUAttentionBackend elif backend == _Backend.PALLAS: logger.info("Using Pallas backend.") from vllm.attention.backends.pallas import PallasAttentionBackend @@ -217,6 +222,9 @@ def which_attn_to_use( logger.info("%s is not supported in AMD GPUs.", selected_backend) return _Backend.ROCM_FLASH + if current_platform.is_hpu(): + return _Backend.HPU_ATTN + # FlashAttn in NVIDIA GPUs. if selected_backend == _Backend.FLASH_ATTN: if not current_platform.has_device_capability(80): diff --git a/vllm/config.py b/vllm/config.py index 4533fb017188c..b0d6e0374e99e 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -376,9 +376,10 @@ def verify_async_output_proc(self, parallel_config, speculative_config, # Reminder: Please update docs/source/serving/compatibility_matrix.rst # If the feature combo become valid - if device_config.device_type not in ("cuda", "tpu", "xpu"): + if device_config.device_type not in ("cuda", "tpu", "xpu", "hpu"): logger.warning( - "Async output processing is only supported for CUDA, TPU, XPU. " + "Async output processing is only supported for CUDA, TPU, XPU " + "and HPU." "Disabling it for other platforms.") self.use_async_output_proc = False return @@ -774,7 +775,6 @@ class LoadConfig: ignore_patterns: The list of patterns to ignore when loading the model. Default to "original/**/*" to avoid repeated loading of llama's checkpoints. - """ load_format: Union[str, LoadFormat, "BaseModelLoader"] = LoadFormat.AUTO @@ -877,6 +877,13 @@ def __init__( raise ValueError( "TPU backend only supports Ray for distributed inference.") + if current_platform.is_hpu() and self.world_size > 1: + if self.distributed_executor_backend is None: + self.distributed_executor_backend = "ray" + if self.distributed_executor_backend != "ray": + raise ValueError( + "HPU backend only supports Ray for distributed inference.") + if self.distributed_executor_backend is None and self.world_size > 1: # We use multiprocessing by default if world_size fits on the # current node and we aren't in a ray placement group. @@ -1079,6 +1086,8 @@ def __init__(self, device: str = "auto") -> None: self.device_type = "cuda" elif is_neuron(): self.device_type = "neuron" + elif current_platform.is_hpu(): + self.device_type = "hpu" elif is_openvino(): self.device_type = "openvino" elif current_platform.is_tpu(): @@ -1644,6 +1653,13 @@ def _get_and_verify_dtype( torch_dtype = torch.float16 else: torch_dtype = config_dtype + + if current_platform.is_hpu() and config_dtype == torch.float16: + logger.info( + "For HPU, we cast models to bfloat16 instead of" + "using float16 by default. Please specify `dtype` if you " + "want to use float16.") + torch_dtype = torch.bfloat16 else: if dtype not in _STR_DTYPE_TO_TORCH_DTYPE: raise ValueError(f"Unknown dtype: {dtype}") diff --git a/vllm/core/block/cpu_gpu_block_allocator.py b/vllm/core/block/cpu_gpu_block_allocator.py index 6eda5f99aa1c8..9727f6e19b84e 100644 --- a/vllm/core/block/cpu_gpu_block_allocator.py +++ b/vllm/core/block/cpu_gpu_block_allocator.py @@ -4,6 +4,7 @@ DeviceAwareBlockAllocator) from vllm.core.block.naive_block import NaiveBlock, NaiveBlockAllocator from vllm.core.block.prefix_caching_block import PrefixCachingBlockAllocator +from vllm.platforms import current_platform from vllm.utils import Device @@ -52,7 +53,11 @@ def create( - The block IDs are assigned contiguously, with GPU block IDs coming before CPU block IDs. """ - block_ids = list(range(num_gpu_blocks + num_cpu_blocks)) + # For HPU, block id 0 is used only for padding + reserved_blocks = 1 if current_platform.is_hpu() else 0 + block_ids = list( + range(reserved_blocks, num_gpu_blocks + num_cpu_blocks)) + num_gpu_blocks -= reserved_blocks gpu_block_ids = block_ids[:num_gpu_blocks] cpu_block_ids = block_ids[num_gpu_blocks:] diff --git a/vllm/distributed/device_communicators/hpu_communicator.py b/vllm/distributed/device_communicators/hpu_communicator.py new file mode 100644 index 0000000000000..cc9b19ce022b5 --- /dev/null +++ b/vllm/distributed/device_communicators/hpu_communicator.py @@ -0,0 +1,48 @@ +import torch +import torch.distributed as dist +from torch.distributed import ProcessGroup + +from vllm.platforms import current_platform + +if current_platform.is_hpu(): + import habana_frameworks.torch as htorch # noqa: F401 + + +class HpuCommunicator: + + def __init__(self, group: ProcessGroup): + if not current_platform.is_hpu(): + self.disabled = True + return + self.disabled = False + self.group = group + self.world_size = dist.get_world_size(self.group) + + def all_reduce(self, x: torch.Tensor) -> torch.Tensor: + # FIXME(kzawora): this is a workaround for a bug in Habana PT bridge + # occurring when PT_HPU_ENABLE_LAZY_COLLECTIVES=true env var is used + # (which is required for tensor parallel HPUGraph inference) + htorch.core.mark_step() + dist.all_reduce(x, group=self.group) + return x + + def all_gather(self, x: torch.Tensor, dim: int = -1) -> torch.Tensor: + world_size = self.world_size + if dim < 0: + # Convert negative dim to positive. + dim += x.dim() + input_size = x.size() + # Allocate output tensor. + output_tensor = torch.empty((world_size, ) + input_size, + dtype=x.dtype, + device=x.device) + # All-gather. + htorch.core.mark_step() + dist.all_gather_into_tensor(output_tensor, x, group=self.group) + # Reshape + output_tensor = output_tensor.movedim(0, dim) + output_tensor = output_tensor.reshape(input_size[:dim] + + (world_size * + input_size[dim], ) + + input_size[dim + 1:]) + return output_tensor diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index 6e1970bfed98a..e70d50a8a0898 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -166,6 +166,7 @@ def __init__( use_pynccl: bool, use_custom_allreduce: bool, use_tpu_communicator: bool, + use_hpu_communicator: bool, use_message_queue_broadcaster: bool = False, group_name: Optional[str] = None, ): @@ -202,6 +203,7 @@ def __init__( self.use_pynccl = use_pynccl self.use_custom_allreduce = use_custom_allreduce self.use_tpu_communicator = use_tpu_communicator + self.use_hpu_communicator = use_hpu_communicator # lazy import to avoid documentation build error from vllm.distributed.device_communicators.custom_all_reduce import ( @@ -230,6 +232,12 @@ def __init__( if use_tpu_communicator and self.world_size > 1: self.tpu_communicator = TpuCommunicator(group=self.cpu_group) + from vllm.distributed.device_communicators.hpu_communicator import ( + HpuCommunicator) + self.hpu_communicator: Optional[HpuCommunicator] + if use_hpu_communicator and self.world_size > 1: + self.hpu_communicator = HpuCommunicator(group=self.device_group) + from vllm.distributed.device_communicators.shm_broadcast import ( MessageQueue) self.mq_broadcaster: Optional[MessageQueue] = None @@ -387,6 +395,11 @@ def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor: if tpu_comm is not None and not tpu_comm.disabled: return tpu_comm.all_gather(input_, dim) + # For HPUs, use HPU communicator. + hpu_comm = self.hpu_communicator + if hpu_comm is not None and not hpu_comm.disabled: + return hpu_comm.all_gather(input_, dim) + if dim < 0: # Convert negative dim to positive. dim += input_.dim() @@ -839,6 +852,7 @@ def init_world_group(ranks: List[int], local_rank: int, use_pynccl=False, use_custom_allreduce=False, use_tpu_communicator=False, + use_hpu_communicator=False, group_name="world", ) @@ -860,6 +874,7 @@ def init_model_parallel_group( use_pynccl=True, use_custom_allreduce=use_custom_allreduce, use_tpu_communicator=True, + use_hpu_communicator=True, use_message_queue_broadcaster=use_message_queue_broadcaster, group_name=group_name, ) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 41963dcb16922..22c63af5e4e26 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -16,6 +16,7 @@ from vllm.executor.executor_base import ExecutorBase from vllm.logger import init_logger from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS +from vllm.platforms import current_platform from vllm.transformers_utils.utils import check_gguf_file from vllm.utils import FlexibleArgumentParser @@ -34,6 +35,7 @@ "openvino", "tpu", "xpu", + "hpu", ] @@ -104,7 +106,9 @@ class EngineArgs: pipeline_parallel_size: int = 1 tensor_parallel_size: int = 1 max_parallel_loading_workers: Optional[int] = None - block_size: int = 16 + # NOTE(kzawora): default block size for Gaudi should be 128 + # smaller sizes still work, but very inefficiently + block_size: int = 16 if not current_platform.is_hpu() else 128 enable_prefix_caching: bool = False disable_sliding_window: bool = False use_v2_block_manager: bool = True @@ -361,7 +365,7 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: parser.add_argument('--block-size', type=int, default=EngineArgs.block_size, - choices=[8, 16, 32], + choices=[8, 16, 32, 64, 128], help='Token block size for contiguous chunks of ' 'tokens. This is ignored on neuron devices and ' 'set to max-model-len') @@ -1039,8 +1043,7 @@ def create_engine_config(self) -> EngineConfig: multi_step_stream_outputs=self.multi_step_stream_outputs, send_delta_data=(envs.VLLM_USE_RAY_SPMD_WORKER and parallel_config.use_ray), - policy=self.scheduling_policy, - ) + policy=self.scheduling_policy) lora_config = LoRAConfig( max_lora_rank=self.max_lora_rank, max_loras=self.max_loras, diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 1f57aecb6481d..66f0b339e6e00 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -617,6 +617,14 @@ def _get_executor_cls( elif engine_config.device_config.device_type == "cpu": from vllm.executor.cpu_executor import CPUExecutorAsync executor_class = CPUExecutorAsync + elif engine_config.device_config.device_type == "hpu": + if distributed_executor_backend == "ray": + initialize_ray_cluster(engine_config.parallel_config) + from vllm.executor.ray_hpu_executor import RayHPUExecutorAsync + executor_class = RayHPUExecutorAsync + else: + from vllm.executor.hpu_executor import HPUExecutorAsync + executor_class = HPUExecutorAsync elif engine_config.device_config.device_type == "openvino": assert distributed_executor_backend is None, ( "Distributed execution is not supported with " diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 61c21887e6816..3ebd3ec2e0030 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -524,6 +524,14 @@ def _get_executor_cls(cls, elif engine_config.device_config.device_type == "cpu": from vllm.executor.cpu_executor import CPUExecutor executor_class = CPUExecutor + elif engine_config.device_config.device_type == "hpu": + if distributed_executor_backend == "ray": + initialize_ray_cluster(engine_config.parallel_config) + from vllm.executor.ray_hpu_executor import RayHPUExecutor + executor_class = RayHPUExecutor + else: + from vllm.executor.hpu_executor import HPUExecutor + executor_class = HPUExecutor elif engine_config.device_config.device_type == "openvino": from vllm.executor.openvino_executor import OpenVINOExecutor executor_class = OpenVINOExecutor diff --git a/vllm/executor/hpu_executor.py b/vllm/executor/hpu_executor.py new file mode 100644 index 0000000000000..34879bc4e7ef5 --- /dev/null +++ b/vllm/executor/hpu_executor.py @@ -0,0 +1,211 @@ +############################################################################### +# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company +############################################################################### + +import contextlib +import os +from typing import Any, Dict, List, Optional, Set, Tuple + +from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase +from vllm.logger import init_logger +from vllm.lora.request import LoRARequest +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.prompt_adapter.request import PromptAdapterRequest +from vllm.sequence import ExecuteModelRequest +from vllm.utils import (get_distributed_init_method, get_ip, get_open_port, + make_async) +from vllm.worker.worker_base import WorkerWrapperBase + +logger = init_logger(__name__) + + +class HPUExecutor(ExecutorBase): + + uses_ray: bool = False + + def _init_executor(self) -> None: + """Initialize the worker and load the model.""" + self._init_worker() + + def _get_worker_kwargs( + self, + local_rank: int = 0, + rank: int = 0, + distributed_init_method: Optional[str] = None) -> Dict[str, Any]: + """Return worker init args for a given rank.""" + if distributed_init_method is None: + distributed_init_method = get_distributed_init_method( + get_ip(), get_open_port()) + return dict( + model_config=self.model_config, + parallel_config=self.parallel_config, + scheduler_config=self.scheduler_config, + device_config=self.device_config, + cache_config=self.cache_config, + load_config=self.load_config, + local_rank=local_rank, + rank=rank, + distributed_init_method=distributed_init_method, + lora_config=self.lora_config, + is_driver_worker=rank == 0, + ) + + def _create_worker(self, + local_rank: int = 0, + rank: int = 0, + distributed_init_method: Optional[str] = None): + wrapper = WorkerWrapperBase( + worker_module_name="vllm.worker.hpu_worker", + worker_class_name="HPUWorker", + ) + wrapper.init_worker(**self._get_worker_kwargs(local_rank, rank, + distributed_init_method)) + return wrapper.worker + + def _init_worker(self): + assert self.parallel_config.world_size == 1, ( + "GPUExecutor only supports single GPU.") + + self.driver_worker = self._create_worker() + self.driver_worker.init_device() + self.driver_worker.load_model() + + def determine_num_available_blocks(self) -> Tuple[int, int]: + """Determine the number of available KV blocks by invoking the + underlying worker. + """ + return self.driver_worker.determine_num_available_blocks() + + def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks) -> None: + """Initialize the KV cache by invoking the underlying worker. + """ + # NOTE: This is logged in the executor because there can be >1 worker + # with other executors. We could log in the engine level, but work + # remains to abstract away the device for non-GPU configurations. + logger.info("# HPU blocks: %d, # CPU blocks: %d", num_gpu_blocks, + num_cpu_blocks) + from vllm_hpu_extension.profiler import HabanaMemoryProfiler + with HabanaMemoryProfiler() as cache_init_m: + self.driver_worker.initialize_cache(num_gpu_blocks, num_cpu_blocks) + msg = f"init_cache_engine took {cache_init_m.get_summary_string()}" + logger.info(msg) + + def finish_measurements(self): + self.driver_worker.finish_measurements() + + def execute_model( + self, + execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]: + # VLLM_HPU_LOG_STEP_GRAPH_COMPILATION - will log graph compilations per engine step, only when there was any - highly recommended to use alongside PT_HPU_METRICS_GC_DETAILS! # noqa:E501 + # VLLM_HPU_LOG_STEP_GRAPH_COMPILATION_ALL - will log graph compilations per engine step, always, even if there were none # noqa:E501 + # VLLM_HPU_LOG_STEP_CPU_FALLBACKS - will log cpu fallbacks per engine step, only when there was any # noqa:E501 + # VLLM_HPU_LOG_STEP_CPU_FALLBACKS_ALL - will log cpu fallbacks per engine step, always, even if there were none # noqa:E501 + log_graph_compilation_all = os.environ.get( + 'VLLM_HPU_LOG_STEP_GRAPH_COMPILATION_ALL', '0') != '0' + log_graph_compilation = os.environ.get( + 'VLLM_HPU_LOG_STEP_GRAPH_COMPILATION', + '0') != '0' or log_graph_compilation_all + log_cpu_fallbacks_all = os.environ.get( + 'VLLM_HPU_LOG_STEP_CPU_FALLBACKS_ALL', '0') != '0' + log_cpu_fallbacks = os.environ.get('VLLM_HPU_LOG_STEP_CPU_FALLBACKS', + '0') != '0' or log_cpu_fallbacks_all + if log_graph_compilation or log_cpu_fallbacks: + from habana_frameworks.torch.hpu.metrics import metric_localcontext + seq_group_metadata_list = execute_model_req.seq_group_metadata_list + is_prompt = any([ + seq_group_metadata.is_prompt + for seq_group_metadata in seq_group_metadata_list + ]) + max_context_len = max([ + max([ + len(v.prompt_token_ids) + len(v.output_token_ids) + for v in seq_group_metadata.seq_data.values() + ]) for seq_group_metadata in seq_group_metadata_list + ]) # whoa, that's some spicy stuff right here + max_num_blocks = ( + (max_context_len - 1) // self.cache_config.block_size) + 1 + input_stats = (f'is_prompt: {is_prompt}, ' + f'num_seqs: {len(seq_group_metadata_list)}, ' + f'max_context_len: {max_context_len}, ' + f'max_num_blocks {max_num_blocks}') + gc_ctx = metric_localcontext( + "graph_compilation" + ) if log_graph_compilation else contextlib.nullcontext() + cpu_fallback_ctx = metric_localcontext( + "cpu_fallback" + ) if log_cpu_fallbacks else contextlib.nullcontext() + with gc_ctx as gc_local_metric, \ + cpu_fallback_ctx as cpu_fallback_local_metric: + output = self.driver_worker.execute_model(execute_model_req) + if (log_graph_compilation and gc_local_metric.stats()[0][1] > 0 + ) or log_graph_compilation_all: + msg = ("VLLM_HPU_STEP_GRAPH_COMPILATION: " + f"{gc_local_metric.stats()}, {input_stats}") + logger.warning(msg) + if (log_cpu_fallbacks and cpu_fallback_local_metric.stats()[0][1] > + 0) or log_cpu_fallbacks_all: + msg = ("VLLM_HPU_STEP_CPU_FALLBACK: " + f"{cpu_fallback_local_metric.stats()}, {input_stats}") + logger.warning(msg) + + return output + + output = self.driver_worker.execute_model(execute_model_req) + return output + + def add_lora(self, lora_request: LoRARequest) -> bool: + assert lora_request.lora_int_id > 0, "lora_id must be greater than 0." + return self.driver_worker.add_lora(lora_request) + + def remove_lora(self, lora_id: int) -> bool: + assert lora_id > 0, "lora_id must be greater than 0." + return self.driver_worker.remove_lora(lora_id) + + def pin_lora(self, lora_id: int) -> bool: + assert lora_id > 0, "lora_id must be greater than 0." + return self.driver_worker.pin_lora(lora_id) + + def list_loras(self) -> Set[int]: + return self.driver_worker.list_loras() + + def add_prompt_adapter( + self, prompt_adapter_request: PromptAdapterRequest) -> bool: + raise NotImplementedError( + "Prompt Adapter is not implemented for HPU backend.") + + def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool: + raise NotImplementedError( + "Prompt Adapter is not implemented for HPU backend.") + + def pin_prompt_adapter(self, prompt_adapter_id: int) -> bool: + raise NotImplementedError( + "Prompt Adapter is not implemented for HPU backend.") + + def list_prompt_adapters(self) -> Set[int]: + raise NotImplementedError( + "Prompt Adapter is not implemented for HPU backend.") + + def check_health(self) -> None: + # GPUExecutor will always be healthy as long as + # it's running. + return + + def start_profile(self) -> None: + self.driver_worker.start_profile() + + def stop_profile(self) -> None: + self.driver_worker.stop_profile() + + def shutdown(self) -> None: + self.driver_worker.shutdown_inc() + + +class HPUExecutorAsync(HPUExecutor, ExecutorAsyncBase): + + async def execute_model_async( + self, + execute_model_req: ExecuteModelRequest, + ) -> List[SamplerOutput]: + output = await make_async(self.driver_worker.execute_model + )(execute_model_req=execute_model_req, ) + return output diff --git a/vllm/executor/ray_hpu_executor.py b/vllm/executor/ray_hpu_executor.py new file mode 100644 index 0000000000000..28d1882cb0db7 --- /dev/null +++ b/vllm/executor/ray_hpu_executor.py @@ -0,0 +1,554 @@ +import asyncio +import os +from collections import defaultdict +from itertools import islice, repeat +from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, + Type) + +import msgspec + +import vllm.envs as envs +from vllm.executor.distributed_gpu_executor import ( # yapf: disable + DistributedGPUExecutor, DistributedGPUExecutorAsync) +from vllm.executor.msgspec_utils import encode_hook +from vllm.executor.ray_utils import RayWorkerWrapper, ray +from vllm.logger import init_logger +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.sequence import ExecuteModelRequest +from vllm.utils import (_run_task_with_lock, get_distributed_init_method, + get_ip, get_open_port, get_vllm_instance_id, + make_async) +from vllm.worker.worker_base import WorkerBase + +if ray is not None: + from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy + +if TYPE_CHECKING: + from ray.util.placement_group import PlacementGroup + +logger = init_logger(__name__) + + +class RayHPUExecutor(DistributedGPUExecutor): + + uses_ray: bool = True + + def _init_executor(self) -> None: + self.forward_dag: Optional["ray.dag.CompiledDAG"] = None + # If the env var is set, it uses the Ray's compiled DAG API + # which optimizes the control plane overhead. + # Run vLLM with VLLM_USE_RAY_COMPILED_DAG=1 to enable it. + # Currently, this requires USE_RAY_SPMD_WORKER=True. + self.use_ray_compiled_dag = envs.VLLM_USE_RAY_COMPILED_DAG + # If the env var is set, then we do not distinguish between the + # "driver worker" vs other workers. Also, the rank 0 worker will + # be executed in a remote Ray worker. Currently this requires + # USE_RAY_COMPILED_DAG=True. + self.use_ray_spmd_worker = envs.VLLM_USE_RAY_SPMD_WORKER + if self.use_ray_compiled_dag: + assert self.use_ray_spmd_worker, ( + "VLLM_USE_RAY_COMPILED_DAG=1 requires " + "VLLM_USE_RAY_SPMD_WORKER=1") + if self.use_ray_spmd_worker: + # TODO: Support SPMD worker for non-DAG Ray executor. + assert self.use_ray_compiled_dag, ( + "VLLM_USE_RAY_SPMD_WORKER=1 requires " + "VLLM_USE_RAY_COMPILED_DAG=1") + + assert self.uses_ray + placement_group = self.parallel_config.placement_group + + # Disable Ray usage stats collection. + ray_usage = os.environ.get("RAY_USAGE_STATS_ENABLED", "0") + if ray_usage != "1": + os.environ["RAY_USAGE_STATS_ENABLED"] = "0" + + # Create the parallel GPU workers. + self._init_workers_ray(placement_group) + + self.input_encoder = msgspec.msgpack.Encoder(enc_hook=encode_hook) + self.output_decoder = msgspec.msgpack.Decoder( + Optional[List[SamplerOutput]]) + + def shutdown(self) -> None: + if hasattr(self, "forward_dag") and self.forward_dag is not None: + self.forward_dag.teardown() + import ray + for worker in self.workers: + ray.kill(worker) + self.forward_dag = None + + def finish_measurements(self): + self._run_workers("finish_measurements") + + def _get_worker_module_and_class( + self + ) -> Tuple[str, str, Optional[Callable[[], + Type[WorkerBase]]]]: # noqa: F821 + worker_class_fn = None + if self.scheduler_config.is_multi_step: + raise NotImplementedError( + "Multi-step execution is not implemented for HPU") + elif self.speculative_config: + raise NotImplementedError( + "Speculative decoding is not implemented for HPU") + else: + worker_module_name = "vllm.worker.hpu_worker" + worker_class_name = "HPUWorker" + return (worker_module_name, worker_class_name, worker_class_fn) + + def _get_worker_wrapper_args(self) -> Dict[str, Any]: + (worker_module_name, worker_class_name, + worker_class_fn) = self._get_worker_module_and_class() + + return dict( + worker_module_name=worker_module_name, + worker_class_name=worker_class_name, + worker_class_fn=worker_class_fn, + trust_remote_code=self.model_config.trust_remote_code, + ) + + def _init_workers_ray(self, placement_group: "PlacementGroup", + **ray_remote_kwargs): + # Otherwise, the ray workers are allocated with a full GPU. + num_gpus = 1 + + # The driver dummy worker does not actually use any resources. + # It holds the resource for the driver worker. + self.driver_dummy_worker: Optional[RayWorkerWrapper] = None + # The remaining workers are the actual ray actors. + self.workers: List[RayWorkerWrapper] = [] + + # Used in ray compiled DAG: indexed first by PP rank, + # and then TP rank. In other words, the inner list is + # the TP group of workers for a PP rank. + self.pp_tp_workers: List[List[RayWorkerWrapper]] = [] + + logger.info("use_ray_spmd_worker: %s", self.use_ray_spmd_worker) + + # Create the workers. + driver_ip = get_ip() + worker_wrapper_kwargs = self._get_worker_wrapper_args() + for bundle_id, bundle in enumerate(placement_group.bundle_specs): + if not bundle.get("HPU", 0): + continue + scheduling_strategy = PlacementGroupSchedulingStrategy( + placement_group=placement_group, + placement_group_capture_child_tasks=True, + placement_group_bundle_index=bundle_id, + ) + + worker = ray.remote( + num_cpus=0, + num_gpus=0, + resources={'HPU': num_gpus}, + scheduling_strategy=scheduling_strategy, + **ray_remote_kwargs, + )(RayWorkerWrapper).remote(**worker_wrapper_kwargs) + + if self.use_ray_spmd_worker: + self.workers.append(worker) + else: + worker_ip = ray.get(worker.get_node_ip.remote()) + if worker_ip == driver_ip and self.driver_dummy_worker is None: + # If the worker is on the same node as the driver, we use it + # as the resource holder for the driver process. + self.driver_dummy_worker = worker + self.driver_worker = RayWorkerWrapper( + **worker_wrapper_kwargs) + else: + # Else, added to the list of workers. + self.workers.append(worker) + + logger.debug("workers: %s", self.workers) + logger.debug("driver_dummy_worker: %s", self.driver_dummy_worker) + if not self.use_ray_spmd_worker and self.driver_dummy_worker is None: + raise ValueError( + "Ray does not allocate any GPUs on the driver node. Consider " + "adjusting the Ray placement group or running the driver on a " + "GPU node.") + + worker_ips = [ + ray.get(worker.get_node_ip.remote()) # type: ignore[attr-defined] + for worker in self.workers + ] + ip_counts: Dict[str, int] = {} + for ip in worker_ips: + ip_counts[ip] = ip_counts.get(ip, 0) + 1 + + def sort_by_driver_then_worker_ip(worker): + """ + Sort the workers based on 3 properties: + 1. If the worker is on the same node as the driver (vllm engine), + it should be placed first. + 2. Then, if the worker is on a node with fewer workers, it should + be placed first. + 3. Finally, if the work is on a node with smaller IP address, it + should be placed first. + """ + ip = ray.get(worker.get_node_ip.remote()) + return (ip != driver_ip, ip_counts[ip], ip) + + # After sorting, the workers on the same node will be + # close to each other, and the workers on the driver + # node will be placed first. + self.workers = sorted(self.workers, key=sort_by_driver_then_worker_ip) + + # Get the set of GPU IDs used on each node. + worker_node_and_gpu_ids = self._run_workers("get_node_and_gpu_ids", + use_dummy_driver=True) + + node_workers = defaultdict(list) # node id -> list of worker ranks + node_gpus = defaultdict(list) # node id -> list of gpu ids + + for i, (node_id, gpu_ids) in enumerate(worker_node_and_gpu_ids): + node_workers[node_id].append(i) + # `gpu_ids` can be a list of strings or integers. + # convert them to integers for consistency. + # NOTE: gpu_ids can be larger than 9 (e.g. 16 GPUs), + # string sorting is not sufficient. + # see https://github.com/vllm-project/vllm/issues/5590 + gpu_ids = [int(x) for x in gpu_ids] + node_gpus[node_id].extend(gpu_ids) + for node_id, gpu_ids in node_gpus.items(): + node_gpus[node_id] = sorted(gpu_ids) + + all_ips = set(worker_ips + [driver_ip]) + n_ips = len(all_ips) + n_nodes = len(node_workers) + + if n_nodes != n_ips: + raise RuntimeError( + f"Every node should have a unique IP address. Got {n_nodes}" + f" nodes with node ids {list(node_workers.keys())} and " + f"{n_ips} unique IP addresses {all_ips}. Please check your" + " network configuration. If you set `VLLM_HOST_IP` or " + "`HOST_IP` environment variable, make sure it is unique for" + " each node.") + + VLLM_INSTANCE_ID = get_vllm_instance_id() + + # Set environment variables for the driver and workers. + all_args_to_update_environment_variables = [({ + "VLLM_INSTANCE_ID": + VLLM_INSTANCE_ID, + "VLLM_TRACE_FUNCTION": + str(envs.VLLM_TRACE_FUNCTION), + }, ) for (node_id, _) in worker_node_and_gpu_ids] + self._run_workers("update_environment_variables", + all_args=all_args_to_update_environment_variables) + + if len(node_gpus) == 1: + # in single node case, we don't need to get the IP address. + # the loopback address is sufficient + # NOTE: a node may have several IP addresses, one for each + # network interface. `get_ip()` might return any of them, + # while they might not work for communication inside the node + # if the network setup is complicated. Using the loopback address + # solves this issue, as it always works for communication inside + # the node. + driver_ip = "127.0.0.1" + distributed_init_method = get_distributed_init_method( + driver_ip, get_open_port()) + + # Initialize the actual workers inside worker wrapper. + init_worker_all_kwargs = [ + self._get_worker_kwargs( + local_rank=node_workers[node_id].index(rank), + rank=rank, + distributed_init_method=distributed_init_method, + ) for rank, (node_id, _) in enumerate(worker_node_and_gpu_ids) + ] + self._run_workers("init_worker", all_kwargs=init_worker_all_kwargs) + + self._run_workers("init_device") + self._run_workers("load_model", + max_concurrent_workers=self.parallel_config. + max_parallel_loading_workers) + + if self.use_ray_spmd_worker: + for pp_rank in range(self.parallel_config.pipeline_parallel_size): + self.pp_tp_workers.append([]) + for tp_rank in range( + self.parallel_config.tensor_parallel_size): + # PP=2, TP=4 + # pp_tp_workers = [[0, 1, 2, 3], [4, 5, 6, 7]] + rank = (pp_rank * self.parallel_config.tensor_parallel_size + ) + tp_rank + assert len(self.pp_tp_workers[pp_rank]) == tp_rank + assert pp_rank < len(self.pp_tp_workers) + self.pp_tp_workers[pp_rank].append(self.workers[rank]) + + # This is the list of workers that are rank 0 of each TP group EXCEPT + # global rank 0. These are the workers that will broadcast to the + # rest of the workers. + self.tp_driver_workers: List[RayWorkerWrapper] = [] + # This is the list of workers that are not drivers and not the first + # worker in a TP group. These are the workers that will be + # broadcasted to. + self.non_driver_workers: List[RayWorkerWrapper] = [] + + # Enforce rank order for correct rank to return final output. + for index, worker in enumerate(self.workers): + # The driver worker is rank 0 and not in self.workers. + rank = index + 1 + if rank % self.parallel_config.tensor_parallel_size == 0: + self.tp_driver_workers.append(worker) + else: + self.non_driver_workers.append(worker) + + def _driver_execute_model( + self, execute_model_req: Optional[ExecuteModelRequest] + ) -> Optional[List[SamplerOutput]]: + """Run execute_model in the driver worker. + + Passing None will cause the driver to stop the model execution + loop running in each of the remote workers. + """ + assert not self.use_ray_spmd_worker, ( + "driver_worker does not exist for VLLM_USE_RAY_SPMD_WORKER=1") + return self.driver_worker.execute_method("execute_model", + execute_model_req) + + def execute_model( + self, + execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]: + if not self.use_ray_spmd_worker: + return super().execute_model(execute_model_req) + + if self.forward_dag is None: + self.forward_dag = self._compiled_ray_dag(enable_asyncio=False) + + serialized_data = self.input_encoder.encode(execute_model_req) + outputs = ray.get(self.forward_dag.execute(serialized_data)) + output = self.output_decoder.decode(outputs[0]) + return output + + def _run_workers( + self, + method: str, + *args, + async_run_tensor_parallel_workers_only: bool = False, + all_args: Optional[List[Tuple[Any, ...]]] = None, + all_kwargs: Optional[List[Dict[str, Any]]] = None, + use_dummy_driver: bool = False, + max_concurrent_workers: Optional[int] = None, + **kwargs, + ) -> Any: + """Runs the given method on all workers. Can be used in the following + ways: + + Args: + - async_run_tensor_parallel_workers_only: If True the method will be + run only in the remote TP workers, not the driver worker. + It will also be run asynchronously and return a list of futures + rather than blocking on the results. + - args/kwargs: All workers share the same args/kwargs + - all_args/all_kwargs: args/kwargs for each worker are specified + individually + """ + if self.use_ray_spmd_worker: + assert not async_run_tensor_parallel_workers_only, ( + "async_run_tensor_parallel_workers_only is not supported for " + "spmd mode.") + + if max_concurrent_workers: + raise NotImplementedError( + "max_concurrent_workers is not supported yet.") + + count = len(self.workers) if not \ + async_run_tensor_parallel_workers_only \ + else len(self.non_driver_workers) + # If using SPMD worker, all workers are the same, so we should execute + # the args on all workers. Otherwise, we skip the first worker's args + # because those args will go to the driver worker. + first_worker_args_index: int = 0 if self.use_ray_spmd_worker else 1 + all_worker_args = repeat(args, count) if all_args is None \ + else islice(all_args, first_worker_args_index, None) + all_worker_kwargs = repeat(kwargs, count) if all_kwargs is None \ + else islice(all_kwargs, first_worker_args_index, None) + + # Start the ray workers first. + ray_workers = self.workers + if async_run_tensor_parallel_workers_only: + ray_workers = self.non_driver_workers + ray_worker_outputs = [ + worker.execute_method.remote(method, *worker_args, **worker_kwargs) + for (worker, worker_args, worker_kwargs + ) in zip(ray_workers, all_worker_args, all_worker_kwargs) + ] + + if async_run_tensor_parallel_workers_only: + # Just return futures + return ray_worker_outputs + + driver_worker_output = [] + # In SPMD mode, the driver worker is the same as any other worker, + # so we only explicitly execute on the driver worker if using a + # non-SPMD worker class. + if not self.use_ray_spmd_worker: + driver_args = args if all_args is None else all_args[0] + driver_kwargs = kwargs if all_kwargs is None else all_kwargs[0] + + # Start the driver worker after all the ray workers. + if not use_dummy_driver: + driver_worker_output = [ + self.driver_worker.execute_method(method, *driver_args, + **driver_kwargs) + ] + else: + assert self.driver_dummy_worker is not None + driver_worker_output = [ + ray.get( + self.driver_dummy_worker.execute_method.remote( + method, *driver_args, **driver_kwargs)) + ] + + # Get the results of the ray workers. + if self.workers: + ray_worker_outputs = ray.get(ray_worker_outputs) + + return driver_worker_output + ray_worker_outputs + + def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None: + """Wait for futures returned from _run_workers() with + async_run_remote_workers_only to complete.""" + ray.get(parallel_worker_tasks) + + def _check_ray_adag_installation(self): + import pkg_resources + from packaging import version + + required_version = version.parse("2.35") + current_version = version.parse( + pkg_resources.get_distribution("ray").version) + # TODO: update the constraint once we adapt to the backward + # incompatible API change from ray 2.36 + if current_version != required_version: + raise ValueError(f"Ray version {required_version} is " + f"required, but found {current_version}") + + import importlib.util + adag_spec = importlib.util.find_spec( + "ray.experimental.compiled_dag_ref") + if adag_spec is None: + raise ValueError("Ray accelerated DAG is not installed. " + "Run `pip install ray[adag]` to install it.") + + def _compiled_ray_dag(self, enable_asyncio: bool): + assert self.parallel_config.use_ray + self._check_ray_adag_installation() + from ray.dag import InputNode, MultiOutputNode + from ray.experimental.channel.torch_tensor_type import TorchTensorType + + with InputNode() as input_data: + # Example DAG: PP=2, TP=4 + # (ExecuteModelReq, None) -> 0 -> (ExecuteModelReq, IntermediateOutput) -> 4 -> SamplerOutput # noqa: E501 + # -> 1 -> (ExecuteModelReq, IntermediateOutput) -> 5 -> SamplerOutput # noqa: E501 + # -> 2 -> (ExecuteModelReq, IntermediateOutput) -> 6 -> SamplerOutput # noqa: E501 + # -> 3 -> (ExecuteModelReq, IntermediateOutput) -> 7 -> SamplerOutput # noqa: E501 + + # All workers in the first TP group will take in the + # ExecuteModelRequest as input. + outputs = [input_data for _ in self.pp_tp_workers[0]] + for pp_rank, tp_group in enumerate(self.pp_tp_workers): + # Each PP worker takes in the output of the previous PP worker, + # and the TP group executes in SPMD fashion. + outputs = [ + worker.execute_model_spmd. + bind( # type: ignore[attr-defined] + outputs[i]) for i, worker in enumerate(tp_group) + ] + + last_pp_rank = len(self.pp_tp_workers) - 1 + if pp_rank < last_pp_rank: + # Specify how intermediate tensors should be passed + # between pp stages, no need to specify for the last + # pp stage. + transport = "auto" + outputs = [ + output.with_type_hint( + TorchTensorType(transport=transport)) + for output in outputs + ] + + forward_dag = MultiOutputNode(outputs) + + return forward_dag.experimental_compile(enable_asyncio=enable_asyncio) + + def __del__(self): + self.shutdown() + + +class RayHPUExecutorAsync(RayHPUExecutor, DistributedGPUExecutorAsync): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.pp_locks: Optional[List[asyncio.Lock]] = None + self.use_ray_spmd_worker = envs.VLLM_USE_RAY_SPMD_WORKER + if not self.use_ray_compiled_dag: + self.driver_exec_method = make_async( + self.driver_worker.execute_method) + + async def execute_model_async( + self, + execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]: + if not self.use_ray_spmd_worker: + return await super().execute_model_async(execute_model_req) + + if self.forward_dag is None: + self.forward_dag = self._compiled_ray_dag(enable_asyncio=True) + + serialized_data = self.input_encoder.encode(execute_model_req) + dag_future = await self.forward_dag.execute_async(serialized_data) + outputs = await dag_future + return self.output_decoder.decode(outputs[0]) + + async def _driver_execute_model_async( + self, + execute_model_req: Optional[ExecuteModelRequest] = None + ) -> List[SamplerOutput]: + assert not self.use_ray_spmd_worker, ( + "driver_worker does not exist for VLLM_USE_RAY_SPMD_WORKER=1") + if not self.tp_driver_workers: + return await self.driver_exec_method("execute_model", + execute_model_req) + if self.pp_locks is None: + # This locks each pipeline parallel stage so multiple virtual + # engines can't execute on the same stage at the same time + # We create the locks here to avoid creating them in the constructor + # which uses a different asyncio loop. + self.pp_locks = [ + asyncio.Lock() + for _ in range(self.parallel_config.pipeline_parallel_size) + ] + + tasks = [ + asyncio.create_task( + _run_task_with_lock(self.driver_exec_method, self.pp_locks[0], + "execute_model", execute_model_req)) + ] + for pp_rank, driver_worker in enumerate(self.tp_driver_workers, + start=1): + tasks.append( + asyncio.create_task( + _run_task_with_lock(driver_worker.execute_method.remote, + self.pp_locks[pp_rank], + "execute_model", execute_model_req))) + + results = await asyncio.gather(*tasks) + + # Only the last PP stage has the final results. + return results[-1] + + async def _start_worker_execution_loop(self): + assert not self.use_ray_spmd_worker, ( + "worker loop is disabled for VLLM_USE_RAY_SPMD_WORKER=1") + coros = [ + worker.execute_method.remote("start_worker_execution_loop") + for worker in self.non_driver_workers + ] + return await asyncio.gather(*coros) + + def __del__(self): + self.shutdown() diff --git a/vllm/executor/ray_utils.py b/vllm/executor/ray_utils.py index 7e46acefc5b0e..e709c35c87ef3 100644 --- a/vllm/executor/ray_utils.py +++ b/vllm/executor/ray_utils.py @@ -242,7 +242,11 @@ def initialize_ray_cluster( # Placement group is already set. return - device_str = "GPU" if not current_platform.is_tpu() else "TPU" + device_str = "GPU" + if current_platform.is_tpu(): + device_str = "TPU" + elif current_platform.is_hpu(): + device_str = 'HPU' # Create placement group for worker processes current_placement_group = ray.util.get_current_placement_group() if current_placement_group: diff --git a/vllm/model_executor/custom_op.py b/vllm/model_executor/custom_op.py index d0e90245ad010..ce7defadd227f 100644 --- a/vllm/model_executor/custom_op.py +++ b/vllm/model_executor/custom_op.py @@ -46,10 +46,9 @@ def forward_tpu(self, *args, **kwargs): # NOTE(woosuk): This is a placeholder for future extensions. return self.forward_native(*args, **kwargs) - def forward_gaudi(self, *args, **kwargs): + def forward_hpu(self, *args, **kwargs): # By default, we assume that Gaudi ops are compatible with the # PyTorch-native implementation. - # NOTE(woosuk): This is a placeholder for future extensions. return self.forward_native(*args, **kwargs) def dispatch_forward(self): @@ -63,6 +62,8 @@ def dispatch_forward(self): return self.forward_hip elif is_cpu(): return self.forward_cpu + elif current_platform.is_hpu(): + return self.forward_hpu elif current_platform.is_tpu(): return self.forward_tpu elif is_xpu(): diff --git a/vllm/model_executor/layers/layernorm.py b/vllm/model_executor/layers/layernorm.py index d55f86056d17c..7b0021829cebe 100644 --- a/vllm/model_executor/layers/layernorm.py +++ b/vllm/model_executor/layers/layernorm.py @@ -92,6 +92,25 @@ def forward_cuda( ) return out + def forward_hpu( + self, + x: torch.Tensor, + residual: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + from vllm_hpu_extension.ops import HPUFusedRMSNorm + if HPUFusedRMSNorm is None: + return self.forward_native(x, residual) + if residual is not None: + orig_shape = x.shape + residual += x.view(residual.shape) + # Note: HPUFusedRMSNorm requires 3D tensors as inputs + x = HPUFusedRMSNorm.apply(residual, self.weight, + self.variance_epsilon) + return x.view(orig_shape), residual + + x = HPUFusedRMSNorm.apply(x, self.weight, self.variance_epsilon) + return x + def forward_xpu( self, x: torch.Tensor, diff --git a/vllm/model_executor/layers/logits_processor.py b/vllm/model_executor/layers/logits_processor.py index 1d5b6fad2e160..bdfef67963e93 100644 --- a/vllm/model_executor/layers/logits_processor.py +++ b/vllm/model_executor/layers/logits_processor.py @@ -109,8 +109,14 @@ def _prune_hidden_states( hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata, ) -> torch.Tensor: - return hidden_states.index_select(0, - sampling_metadata.selected_token_indices) + # NOTE(kzawora): The if guard is needed for Gaudi - in some scenarios + # (warmup, profile_run) we might not have selected_token_indices, + # so we skip pruning. + if sampling_metadata.selected_token_indices is not None: + return hidden_states.index_select( + 0, sampling_metadata.selected_token_indices) + else: + return hidden_states def _apply_logits_processors( diff --git a/vllm/model_executor/layers/rotary_embedding.py b/vllm/model_executor/layers/rotary_embedding.py index 2ed44e2093bbe..ca790d966c6a0 100644 --- a/vllm/model_executor/layers/rotary_embedding.py +++ b/vllm/model_executor/layers/rotary_embedding.py @@ -194,6 +194,61 @@ def forward_xpu( self.cos_sin_cache, self.is_neox_style) return query, key + def forward_hpu( + self, + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + offsets: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + from habana_frameworks.torch.hpex.kernels import ( + RotaryPosEmbeddingMode, apply_rotary_pos_emb) + positions = positions.flatten() + if offsets is not None: + positions = positions + offsets + num_tokens = positions.shape[0] + cos_sin = self.cos_sin_cache.index_select(0, positions).view( + num_tokens, 1, -1) + cos, sin = cos_sin.chunk(2, dim=-1) + # HPU RoPE kernel requires hidden dimension for cos and sin to be equal + # to query hidden dimension, so the original tensors need to be + # expanded + # GPT-NeoX kernel requires position_ids = None, offset, mode = BLOCKWISE + # and expansion of cos/sin tensors via concatenation + # GPT-J kernel requires position_ids = None, offset = 0, mode = PAIRWISE + # and expansion of cos/sin tensors via repeat_interleave + rope_mode: RotaryPosEmbeddingMode + if self.is_neox_style: + rope_mode = RotaryPosEmbeddingMode.BLOCKWISE + cos = torch.cat((cos, cos), dim=-1) + sin = torch.cat((sin, sin), dim=-1) + else: + rope_mode = RotaryPosEmbeddingMode.PAIRWISE + sin = torch.repeat_interleave(sin, + 2, + dim=-1, + output_size=cos_sin.shape[-1]) + cos = torch.repeat_interleave(cos, + 2, + dim=-1, + output_size=cos_sin.shape[-1]) + + query_shape = query.shape + query = query.view(num_tokens, -1, self.head_size) + query_rot = query[..., :self.rotary_dim] + query_pass = query[..., self.rotary_dim:] + query_rot = apply_rotary_pos_emb(query_rot, cos, sin, None, 0, + rope_mode) + query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape) + + key_shape = key.shape + key = key.view(num_tokens, -1, self.head_size) + key_rot = key[..., :self.rotary_dim] + key_pass = key[..., self.rotary_dim:] + key_rot = apply_rotary_pos_emb(key_rot, cos, sin, None, 0, rope_mode) + key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape) + return query, key + def extra_repr(self) -> str: s = f"head_size={self.head_size}, rotary_dim={self.rotary_dim}" s += f", max_position_embeddings={self.max_position_embeddings}" diff --git a/vllm/model_executor/layers/vocab_parallel_embedding.py b/vllm/model_executor/layers/vocab_parallel_embedding.py index b448557af13b3..52771f50a7a23 100644 --- a/vllm/model_executor/layers/vocab_parallel_embedding.py +++ b/vllm/model_executor/layers/vocab_parallel_embedding.py @@ -12,6 +12,7 @@ QuantizationConfig, QuantizeMethodBase, method_has_implemented_embedding) from vllm.model_executor.parameter import BasevLLMParameter from vllm.model_executor.utils import set_weight_attrs +from vllm.platforms import current_platform DEFAULT_VOCAB_PADDING_SIZE = 64 @@ -382,8 +383,20 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): # Copy the data. loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size) - param[:loaded_weight.shape[0]].data.copy_(loaded_weight) - param[loaded_weight.shape[0]:].data.fill_(0) + + if current_platform.is_hpu(): + # FIXME(kzawora): Weight copy with slicing bugs out on Gaudi here, + # so we're using a workaround. Remove this when fixed in + # HPU PT bridge. + padded_weight = torch.cat([ + loaded_weight, + torch.zeros(param.shape[0] - loaded_weight.shape[0], + *loaded_weight.shape[1:]) + ]) + param.data.copy_(padded_weight) + else: + param[:loaded_weight.shape[0]].data.copy_(loaded_weight) + param[loaded_weight.shape[0]:].data.fill_(0) def forward(self, input_): if self.tp_size > 1: diff --git a/vllm/model_executor/sampling_metadata.py b/vllm/model_executor/sampling_metadata.py index ee02368bec8a8..84f35f75a0c32 100644 --- a/vllm/model_executor/sampling_metadata.py +++ b/vllm/model_executor/sampling_metadata.py @@ -284,7 +284,8 @@ def _prepare_seq_groups( else: # Decode prompt_logprob_len = 0 - query_len = query_lens[i] if query_lens is not None else 1 + query_len = query_lens[i] if query_lens is not None and len( + query_lens) > 0 else 1 sample_len = len(seq_ids) * query_len if do_sample else 0 if sampling_params.seed is not None and generators is not None: diff --git a/vllm/platforms/__init__.py b/vllm/platforms/__init__.py index c648862b2d757..f8f782966769c 100644 --- a/vllm/platforms/__init__.py +++ b/vllm/platforms/__init__.py @@ -42,6 +42,13 @@ except Exception: pass +is_hpu = False +try: + from importlib import util + is_hpu = util.find_spec('habana_frameworks') is not None +except Exception: + pass + is_xpu = False try: @@ -69,6 +76,9 @@ elif is_rocm: from .rocm import RocmPlatform current_platform = RocmPlatform() +elif is_hpu: + from .hpu import HpuPlatform + current_platform = HpuPlatform() elif is_xpu: from .xpu import XPUPlatform current_platform = XPUPlatform() diff --git a/vllm/platforms/hpu.py b/vllm/platforms/hpu.py new file mode 100644 index 0000000000000..170cfff94f90d --- /dev/null +++ b/vllm/platforms/hpu.py @@ -0,0 +1,11 @@ +import torch + +from .interface import Platform, PlatformEnum + + +class HpuPlatform(Platform): + _enum = PlatformEnum.HPU + + @staticmethod + def inference_mode(): + return torch.no_grad() diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index 00742a290e42a..3e79d2e9f0bd6 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -8,6 +8,7 @@ class PlatformEnum(enum.Enum): CUDA = enum.auto() ROCM = enum.auto() TPU = enum.auto() + HPU = enum.auto() XPU = enum.auto() CPU = enum.auto() UNSPECIFIED = enum.auto() @@ -42,6 +43,9 @@ def is_rocm(self) -> bool: def is_tpu(self) -> bool: return self._enum == PlatformEnum.TPU + def is_hpu(self) -> bool: + return self._enum == PlatformEnum.HPU + def is_xpu(self) -> bool: return self._enum == PlatformEnum.XPU diff --git a/vllm/utils.py b/vllm/utils.py index 8debae52b288c..ae4c6d60b81fe 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -398,6 +398,9 @@ def seed_everything(seed: int) -> None: if is_xpu(): torch.xpu.manual_seed_all(seed) + if current_platform.is_hpu(): + torch.hpu.manual_seed_all(seed) + def random_uuid() -> str: return str(uuid.uuid4().hex) @@ -778,6 +781,9 @@ def is_pin_memory_available() -> bool: elif is_neuron(): print_warning_once("Pin memory is not supported on Neuron.") return False + elif current_platform.is_hpu(): + print_warning_once("Pin memory is not supported on HPU.") + return False elif is_cpu() or is_openvino(): return False return True diff --git a/vllm/worker/hpu_model_runner.py b/vllm/worker/hpu_model_runner.py new file mode 100644 index 0000000000000..ed71ba8f853a7 --- /dev/null +++ b/vllm/worker/hpu_model_runner.py @@ -0,0 +1,2033 @@ +############################################################################### +# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company +############################################################################### + +import collections +import contextlib +import dataclasses +import functools +import gc +import itertools +import math +import operator +import os +import time +from array import array +from dataclasses import dataclass, field +from enum import IntEnum +from typing import (TYPE_CHECKING, Any, Callable, Dict, List, NamedTuple, + Optional, Set, Tuple, Type, TypeVar, Union) + +import habana_frameworks.torch as htorch +import habana_frameworks.torch.internal.bridge_config as bc +import torch +from vllm_hpu_extension.ops import LoraMask as LoraMask +from vllm_hpu_extension.profiler import (HabanaHighLevelProfiler, + HabanaMemoryProfiler, format_bytes) + +from vllm.attention import AttentionMetadata, get_attn_backend +from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, + ModelConfig, ObservabilityConfig, ParallelConfig, + PromptAdapterConfig, SchedulerConfig) +from vllm.distributed.parallel_state import get_world_group +from vllm.logger import init_logger +from vllm.lora.layers import LoRAMapping +from vllm.lora.request import LoRARequest +from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager +from vllm.model_executor import SamplingMetadata +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.model_executor.model_loader import get_model +from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs, + MultiModalInputs) +from vllm.sampling_params import SamplingParams +from vllm.sequence import (IntermediateTensors, SequenceData, + SequenceGroupMetadata) +from vllm.utils import is_pin_memory_available, make_tensor_with_pad +from vllm.worker.model_runner_base import ( + ModelRunnerBase, ModelRunnerInputBase, + _add_attn_metadata_broadcastable_dict, + _add_sampling_metadata_broadcastable_dict, + _init_attn_metadata_from_tensor_dict, + _init_sampling_metadata_from_tensor_dict) + +if TYPE_CHECKING: + from vllm.attention.backends.abstract import AttentionBackend + +logger = init_logger(__name__) + +_TYPE_CACHE = {} +# These values are assumed to be zero in several places. +# Use caution when updating them! +_PAD_SLOT_ID = 0 +_PAD_BLOCK_ID = 0 + +LORA_WARMUP_RANK = 8 + + +class Singleton(type): + _instances: Dict[type, object] = {} + + def __call__(cls, *args, **kwargs): + if cls not in cls._instances: + cls._instances[cls] = super(Singleton, + cls).__call__(*args, **kwargs) + return cls._instances[cls] + + +@dataclass +class HPUBucketingGlobalState(metaclass=Singleton): + prompt_bs_bucket_cfg: Tuple[int, int, int] = field(init=False) + decode_bs_bucket_cfg: Tuple[int, int, int] = field(init=False) + prompt_seq_bucket_cfg: Tuple[int, int, int] = field(init=False) + decode_block_bucket_cfg: Tuple[int, int, int] = field(init=False) + prompt_buckets: List[Tuple[int, int]] = field(init=False) + decode_buckets: List[Tuple[int, int]] = field(init=False) + + +def subtuple(obj: object, + typename: str, + to_copy: List[str], + to_override: Optional[Dict[str, object]] = None): + if obj is None: + return None + if to_override is None: + to_override = {} + fields = set(to_copy) | set(to_override.keys()) + values = {f: to_override.get(f, getattr(obj, f)) for f in fields} + if typename not in _TYPE_CACHE: + _TYPE_CACHE[typename] = collections.namedtuple(typename, + ' '.join(fields)) + return _TYPE_CACHE[typename](**values) + + +def read_bucket_settings(phase: str, dim: str, **defaults): + """Read bucketing configuration from env variables. + + phase is either 'prompt' or 'decode' + dim is either 'bs', 'seq' or 'block' + param is either 'min', 'step' or 'max' + example env variable: VLLM_DECODE_BS_BUCKET_STEP=128 + """ + params = ['min', 'step', 'max'] + env_vars = [f'VLLM_{phase}_{dim}_BUCKET_{p}'.upper() for p in params] + default_values = [defaults[p] for p in params] + values = [ + int(os.environ.get(e, d)) for e, d in zip(env_vars, default_values) + ] + for e, v, d in zip(env_vars, values, default_values): + logger.info('%s=%s (default:%s)', e, v, d) + return values + + +def warmup_range(config: Tuple[int, int, int]): + """Generate a warmup range. + + Start from bmin and multiply by 2 until you reach bstep. + Then, increase the values in the range by the value of bstep until you + reach bmax. + + Example: + bmin = 2, bstep = 32, bmax = 64 + => ramp_up = (2, 4, 8, 16) + => stable = (32, 64) + => return ramp_up + stable => (2, 4, 8, 16, 32, 64) + """ + bmin, bstep, bmax = config + assert bmin <= bmax, ("Min. batch size cannot be greater than max. " + "batch size. If you want to skip warmup, " + "set VLLM_SKIP_WARMUP=true") + base = itertools.repeat(2) + ramp_up_acc = itertools.accumulate(base, func=operator.mul, initial=bmin) + ramp_up_tw = itertools.takewhile(lambda x: x < bstep and x <= bmax, \ + ramp_up_acc) + stable = range(bstep, bmax + 1, bstep) + buckets = list(ramp_up_tw) + list(stable) + return list(filter(lambda bucket: bucket >= bmin, buckets)) + + +def generate_prompt_buckets(bs_bucket_config, + seq_bucket_config, + max_num_batched_tokens=None): + buckets = list( + itertools.product(warmup_range(bs_bucket_config), + warmup_range(seq_bucket_config))) + if len(buckets) == 0: + msg = ("No buckets could be captured with following config " + f"(min, step, max_warmup): " + f"bs:{bs_bucket_config}, " + f"seq:{seq_bucket_config}") + raise ValueError(msg) + + filtered_buckets = buckets + if max_num_batched_tokens is not None: + # Remove buckets exceeding batch token budget + filtered_buckets = list( + filter( + lambda bucket: bucket[0] * bucket[1] <= max_num_batched_tokens, + buckets)) + + if len(filtered_buckets) == 0: + # we can handle this if we ignore max_num_batched_tokens + min_bucket_bs, min_bucket_seq = min(buckets, + key=lambda b: (b[0] * b[1])) + min_reqd_budget = min_bucket_bs * min_bucket_seq + msg = ( + "The current bucketing configuration " + f"(min, step, max_warmup): " + f"bs:{bs_bucket_config}, " + f"seq:{seq_bucket_config} cannot be used with specified " + f"max_num_batched_tokens ({max_num_batched_tokens}), as the " + f"smallest bucket ({min_reqd_budget}) would exceed token " + "budget. Please increase max_num_batched_tokens or decrease " + "bucket minimum Ignoring max_num_batched_tokens at risk of " + "out-of-memory errors.") + logger.error(msg) + return list( + sorted(buckets, key=lambda b: (b[0] * b[1], b[1], b[0]))), [] + + captured_buckets = list( + sorted(filtered_buckets, key=lambda b: (b[0] * b[1], b[1], b[0]))) + omitted_buckets = list( + sorted([x for x in buckets if x not in filtered_buckets])) + return captured_buckets, omitted_buckets + + +def generate_decode_buckets(bs_bucket_config, blocks_bucket_config, + max_blocks): + buckets = [] + bs_buckets = warmup_range(bs_bucket_config) + block_buckets = warmup_range(blocks_bucket_config) + bmin, bstep, bmax = blocks_bucket_config + last_bucket = round_up(max_blocks, bstep) + for bs in bs_buckets: + for blocks in block_buckets: + if blocks < bs: + continue + if blocks > last_bucket: + break + buckets.append((bs, blocks)) + return list(sorted(buckets, key=lambda b: (b[0] * b[1], b[1], b[0]))) + + +def next_pow2(value: int, base: int): + res = base + while value > 1: + value = (value + 1) // 2 + res *= 2 + return res + + +def round_up(value: int, k: int): + return (value + k - 1) // k * k + + +def find_bucket(value: int, config: Tuple[int, int, int]): + bmin, bstep, _ = config + next_step = round_up(value, bstep) + next_pow = next_pow2(value, bmin) + return max(bmin, min(next_step, next_pow)) + + +def align_workers(value, op): + group = get_world_group().cpu_group + world_size = torch.distributed.get_world_size() + if world_size <= 1: + return value + value_t = torch.tensor(value, device='cpu') + torch.distributed.all_reduce(value_t, op=op, group=group) + return value_t.item() + + +def setup_profiler(): + schedule = torch.profiler.schedule(wait=0, warmup=2, active=1, repeat=1) + DEVICE = 'hpu' + activities = [torch.profiler.ProfilerActivity.CPU] + activities.extend([torch.profiler.ProfilerActivity.HPU] if DEVICE == + 'hpu' else []) + #from habana_frameworks.torch.activity_profiler import DebugActivity + #debug_activities=[DebugActivity.BRIDGE_FUNCTION_CALLS] + + profiler = torch.profiler.profile( + schedule=schedule, + activities=activities, + #debug_activities=debug_activities, + on_trace_ready=torch.profiler.tensorboard_trace_handler('.', + use_gzip=True), + record_shapes=False, + with_stack=True) + return profiler + + +def pad_list(list, k, v): + target_len = round_up(len(list), k) + padding = target_len - len(list) + return list + [v] * padding + + +def precompute_indices_and_offsets(block_size, slot_mapping, is_prompt): + slot_mapping = slot_mapping.flatten() + indices = torch.div(slot_mapping, block_size, rounding_mode="floor") + if is_prompt: + indices = indices.unflatten(0, (-1, block_size))[:, 0] + offsets = None + else: + offsets = torch.fmod(slot_mapping, block_size) + return indices, offsets + + +class HpuModelAdapter(): + + def __init__(self, model, block_size, dtype, enforce_eager): + self.model = model + self.prefill_use_fusedsdpa = os.getenv('VLLM_PROMPT_USE_FUSEDSDPA', + '0').lower() in ['1', 'true'] + self.block_size = block_size + self.dtype = dtype + if not htorch.utils.internal.is_lazy() and not enforce_eager: + self.model = torch.compile(self.model, + backend='hpu_backend', + dynamic=False) + + def _set_attn_bias(self, attn_metadata, batch_size, seq_len, device, + dtype): + prefill_metadata = attn_metadata + if prefill_metadata is None or self.prefill_use_fusedsdpa: + return attn_metadata + + seq_lens_t = prefill_metadata.seq_lens_tensor + len_mask = (torch.arange(0, seq_len, device=device, + dtype=torch.int32).view(1, seq_len).ge( + seq_lens_t.unsqueeze(-1)).view( + batch_size, 1, 1, seq_len)) + causal_mask = torch.triu(torch.ones((batch_size, 1, seq_len, seq_len), + device=device, + dtype=torch.bool), + diagonal=1) + mask = causal_mask.logical_or(len_mask) + attn_bias = (torch.zeros_like(mask, dtype=dtype).masked_fill_( + mask, -math.inf)) + attn_metadata = prefill_metadata._replace(attn_bias=attn_bias) + return attn_metadata + + def _set_block_mapping(self, metadata, batch_size, device, dtype): + mask = torch.arange(0, + self.block_size, + device=device, + dtype=torch.int32).unsqueeze(0) + mask = mask >= metadata.block_usage.unsqueeze(-1) + attn_bias = (torch.zeros_like(mask, dtype=dtype).masked_fill_( + mask, -math.inf)) + block_mapping = torch.nn.functional.one_hot(metadata.block_mapping, + num_classes=batch_size) + block_mapping = block_mapping.to(dtype) + metadata = metadata._replace(block_mapping=block_mapping, + attn_bias=attn_bias) + return metadata + + def _update_metadata(self, attn_metadata, batch_size, seq_len, device, + dtype): + if attn_metadata.is_prompt: + meta = attn_metadata + attn_metadata = self._set_attn_bias(meta, batch_size, seq_len, + device, dtype) + else: + meta = attn_metadata + attn_metadata = self._set_block_mapping(meta, batch_size, device, + dtype) + return attn_metadata + + def forward(self, *args, **kwargs): + kwargs = kwargs.copy() + selected_token_indices = kwargs.pop('selected_token_indices') + if 'warmup_mode' in kwargs: + kwargs.pop('warmup_mode') + input_ids = kwargs['input_ids'] + kwargs['attn_metadata'] = self._update_metadata( + kwargs['attn_metadata'], input_ids.size(0), input_ids.size(1), + input_ids.device, self.dtype) + LoraMask.setLoraMask(kwargs.pop('lora_mask')) + hidden_states = self.model(*args, **kwargs) + hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) + hidden_states = hidden_states.index_select(0, selected_token_indices) + return hidden_states + + def compute_logits(self, *args, **kwargs): + return self.model.compute_logits(*args, **kwargs) + + def sample(self, *args, **kwargs): + return self.model.sample(*args, **kwargs) + + +class PreparePromptMetadata(NamedTuple): + input_tokens: torch.Tensor + input_positions: List[List[int]] + attn_metadata: Optional[AttentionMetadata] + seq_lens: List[int] + query_lens: List[int] + lora_index_mapping: List[List[int]] + lora_prompt_mapping: List[List[int]] + lora_requests: Set[LoRARequest] + multi_modal_kwargs: Optional[Dict[str, BatchedTensorInputs]] + slot_mapping: List[List[int]] + lora_ids: List[int] + + @classmethod + def empty(cls): + return PreparePromptMetadata(input_tokens=[], + input_positions=[], + attn_metadata=None, + seq_lens=[], + query_lens=[], + lora_index_mapping=[], + lora_prompt_mapping=[], + lora_requests=set(), + multi_modal_kwargs=None, + slot_mapping=[], + lora_ids=[]) + + +class PrepareDecodeMetadata(NamedTuple): + input_tokens: torch.Tensor + input_positions: List[List[int]] + attn_metadata: Optional[AttentionMetadata] + lora_index_mapping: List[List[int]] + lora_prompt_mapping: List[List[int]] + lora_requests: Set[LoRARequest] + slot_mapping: List[List[int]] + lora_ids: List[int] + + @classmethod + def empty(cls): + return PrepareDecodeMetadata(input_tokens=[], + input_positions=[], + attn_metadata=None, + lora_index_mapping=[], + lora_prompt_mapping=[], + lora_requests=set(), + slot_mapping=[], + lora_ids=[]) + + +# How batches are constructed. +class BatchType(IntEnum): + # Every batch is prefill. + PREFILL = 0 + # Every batch is decode. + DECODE = 1 + # Batch is a mixture of prefill and decode. + MIXED = 2 + + +TModelInputForHPU = TypeVar('TModelInputForHPU', bound="ModelInputForHPU") + + +@dataclasses.dataclass(frozen=True) +class ModelInputForHPU(ModelRunnerInputBase): + """ + This base class contains metadata needed for the base model forward pass + but not metadata for possible additional steps, e.g., sampling. Model + runners that run additional steps should subclass this method to add + additional fields. + """ + input_tokens: Optional[torch.Tensor] = None + input_positions: Optional[torch.Tensor] = None + seq_lens: Optional[List[int]] = None + query_lens: Optional[List[int]] = None + lora_mapping: Optional["LoRAMapping"] = None + lora_requests: Optional[Set[LoRARequest]] = None + attn_metadata: Optional["AttentionMetadata"] = None + multi_modal_kwargs: Optional[Dict[str, torch.Tensor]] = None + real_batch_size: Optional[int] = None + batch_size_padded: Optional[int] = None + virtual_engine: int = 0 + lora_ids: Optional[List[int]] = None + async_callback: Optional[Callable] = None + + def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: + tensor_dict = { + "input_tokens": self.input_tokens, + "input_positions": self.input_positions, + "lora_requests": self.lora_requests, + "lora_mapping": self.lora_mapping, + "multi_modal_kwargs": self.multi_modal_kwargs, + "real_batch_size": self.real_batch_size, + "batch_size_padded": self.batch_size_padded, + "virtual_engine": self.virtual_engine, + "lora_ids": self.lora_ids, + } + _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata) + return tensor_dict + + @classmethod + def from_broadcasted_tensor_dict( + cls: Type[TModelInputForHPU], + tensor_dict: Dict[str, Any], + attn_backend: Optional["AttentionBackend"] = None, + ) -> TModelInputForHPU: + if attn_backend is not None: + tensor_dict = _init_attn_metadata_from_tensor_dict( + attn_backend, tensor_dict) + return cls(**tensor_dict) + + +@dataclasses.dataclass(frozen=True) +class ModelInputForHPUWithSamplingMetadata(ModelInputForHPU): + """ + Used by the ModelRunner. + """ + sampling_metadata: Optional["SamplingMetadata"] = None + # Used for speculative decoding. We do not broadcast it because it is only + # used by the driver worker. + is_prompt: Optional[bool] = None + + def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: + tensor_dict = { + "input_tokens": self.input_tokens, + "input_positions": self.input_positions, + "lora_requests": self.lora_requests, + "lora_mapping": self.lora_mapping, + "multi_modal_kwargs": self.multi_modal_kwargs, + "lora_ids": self.lora_ids, + } + _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata) + _add_sampling_metadata_broadcastable_dict(tensor_dict, + self.sampling_metadata) + return tensor_dict + + @classmethod + def from_broadcasted_tensor_dict( + cls, + tensor_dict: Dict[str, Any], + attn_backend: Optional["AttentionBackend"] = None, + ) -> "ModelInputForHPUWithSamplingMetadata": + tensor_dict = _init_sampling_metadata_from_tensor_dict(tensor_dict) + # FIXME(kzawora): this fails for whatever reason - why? + if attn_backend is not None: + tensor_dict = _init_attn_metadata_from_tensor_dict( + attn_backend, tensor_dict) + return cls(**tensor_dict) + + +class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]): + """ + Helper class for shared methods between GPU model runners. + """ + _model_input_cls: Type[TModelInputForHPU] + + def __init__( + self, + model_config: ModelConfig, + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + device_config: DeviceConfig, + cache_config: CacheConfig, + load_config: LoadConfig, + lora_config: Optional[LoRAConfig], + kv_cache_dtype: Optional[str] = "auto", + is_driver_worker: bool = False, + prompt_adapter_config: Optional[PromptAdapterConfig] = None, + return_hidden_states: bool = False, + observability_config: Optional[ObservabilityConfig] = None, + ): + self.model_config = model_config + self.parallel_config = parallel_config + self.scheduler_config = scheduler_config + self.device_config = device_config + self.cache_config = cache_config + self.lora_config = lora_config + self.load_config = load_config + self.is_driver_worker = is_driver_worker + self.prompt_adapter_config = prompt_adapter_config + self.return_hidden_states = return_hidden_states + self.observability_config = observability_config + + self.sliding_window = (model_config.get_sliding_window() + if model_config is not None else None) + self.device_config = (device_config + if device_config is not None else DeviceConfig()) + + self.device = self.device_config.device + self.enforce_eager = self.model_config.enforce_eager + self.max_num_seqs = self.scheduler_config.max_num_seqs + # NOTE(kzawora): Change that to scheduler_config.max_num_prefill_seqs + # once padding-aware scheduling gets merged + self.max_num_prefill_seqs = 64 + self.max_model_len = self.scheduler_config.max_model_len + self.max_num_batched_tokens = \ + self.scheduler_config.max_num_batched_tokens + self.block_size = cache_config.block_size + + self.pin_memory = is_pin_memory_available() + self.kv_cache_dtype = kv_cache_dtype + + self.attn_backend = get_attn_backend( + self.model_config.get_head_size(), + self.model_config.get_sliding_window(), + self.model_config.dtype, + self.kv_cache_dtype, + self.block_size, + self.model_config.is_attention_free, + ) + + # Lazy initialization + self.lora_manager: LRUCacheWorkerLoRAManager = None + self.model: torch.nn.Module = None + self.inc_initialized_successfully = False + + # Profiler stats + self.profiler = HabanaHighLevelProfiler() + self.profiler_counter_helper = HabanaProfilerCounterHelper() + self.seen_configs: set = set() + self._mem_margin: Optional[int] = None + self.bucketing_global_state = HPUBucketingGlobalState() + self._setup_buckets() + self._set_gc_threshold() + + def _set_gc_threshold(self) -> None: + # Read https://docs.python.org/3/library/gc.html#gc.set_threshold + # for comprehensive description of gc generations. + # We can either use VLLM_GC_THR_GEN[0-2] (this has higher priority) + # to set particular generation threshold or use simpler + # VLLM_GC_THR_MULTIPLIER to multiply default values. + default_gc_thrs = list(gc.get_threshold()) + requested_gc_thrs = [0] * len(default_gc_thrs) + for i in range(len(default_gc_thrs)): + requested_gc_thrs[i] = int( + os.environ.get(f'VLLM_GC_THR_GEN{i}', default_gc_thrs[i])) + if requested_gc_thrs == default_gc_thrs: + gc_thr_multiplier = int(os.environ.get('VLLM_GC_THR_MULTIPLIER', + 2)) + requested_gc_thrs = [ + t * gc_thr_multiplier for t in default_gc_thrs + ] + gc.set_threshold(*requested_gc_thrs) + + # Multi-modal data support + self.multi_modal_input_mapper = MULTIMODAL_REGISTRY \ + .create_input_mapper(self.model_config) + + self.skip_warmup = os.environ.get('VLLM_SKIP_WARMUP', + 'false').lower() == 'true' + + def load_model(self) -> None: + import habana_frameworks.torch.core as htcore + if self.model_config.quantization == 'inc' or \ + self.model_config.quantization == 'fp8': + htcore.hpu_set_env() + with HabanaMemoryProfiler() as m: + with HabanaMemoryProfiler() as m_getmodel: + self.model = get_model(model_config=self.model_config, + device_config=self.device_config, + load_config=self.load_config, + lora_config=self.lora_config, + parallel_config=self.parallel_config, + scheduler_config=self.scheduler_config, + cache_config=self.cache_config) + msg = ("Pre-loading model weights on " + f"{next(self.model.parameters()).device} " + f"took {m_getmodel.get_summary_string()}") + logger.info(msg) + + if self.lora_config: + assert hasattr(self.model, "supported_lora_modules" + ) and self.model.supported_lora_modules, ( + "Model does not support LoRA") + assert hasattr(self.model, "embedding_modules" + ), "Model does not have embedding_modules" + assert hasattr( + self.model, "embedding_padding_modules" + ), "Model does not have embedding_padding_modules" + self.lora_manager = LRUCacheWorkerLoRAManager( + self.scheduler_config.max_num_seqs, + self.scheduler_config.max_num_batched_tokens, + self.vocab_size, self.lora_config, self.device, + self.model.embedding_modules, + self.model.embedding_padding_modules) + self.model = self.lora_manager.create_lora_manager(self.model) + + if self.model_config.quantization == 'inc': + logger.info("Preparing model with INC..") + with HabanaMemoryProfiler() as m_inc: + from neural_compressor.torch.quantization import ( + FP8Config, convert, prepare) + config = FP8Config.from_json_file( + os.getenv("QUANT_CONFIG", "")) + if config.measure: + self.model = prepare(self.model, config) + elif config.quantize: + self.model = convert(self.model, config) + htcore.hpu_initialize(self.model, + mark_only_scales_as_const=True) + self.inc_initialized_successfully = True + logger.info("Preparing model with INC took %s", + m_inc.get_summary_string()) + else: + self.model = self.model.to("hpu") + htcore.mark_step() + torch.hpu.synchronize() + + with HabanaMemoryProfiler() as m_wrap: + self.model = _maybe_wrap_in_hpu_graph( + self.model, + self.block_size, + dtype=self.model_config.dtype, + enforce_eager=self.enforce_eager) + msg = f"Wrapping in HPU Graph took {m_wrap.get_summary_string()}" + logger.info(msg) + + self.model_memory_usage = m.consumed_device_memory + msg = f"Loading model weights took in total {m.get_summary_string()}" + logger.info(msg) + + def _use_graphs(self, batch_size, seq_len, is_prompt): + if self.enforce_eager: + return False + if self.skip_warmup: + return True + return (batch_size, seq_len, is_prompt) in self.graphed_buckets + + def _is_valid_bucket(self, bucket): + return bucket[0] * bucket[1] <= self.max_num_batched_tokens + + def _setup_buckets(self) -> None: + align_bs = lambda x: min(self.max_num_seqs, x) + #FIXME: The default values should be max_model_len + max_prompt_seq = 1024 + max_decode_seq = 2048 + self.bucketing_global_state.prompt_bs_bucket_cfg = read_bucket_settings( + 'prompt', + 'bs', + min=1, + step=align_bs(32), + max=self.max_num_prefill_seqs) + self.bucketing_global_state.decode_bs_bucket_cfg = read_bucket_settings( + 'decode', 'bs', min=1, step=align_bs(32), max=self.max_num_seqs) + self.bucketing_global_state.prompt_seq_bucket_cfg = \ + read_bucket_settings( + 'prompt', + 'seq', + min=self.block_size, + step=self.block_size, + max=max_prompt_seq) + self.bucketing_global_state.decode_block_bucket_cfg = \ + read_bucket_settings( + 'decode', + 'block', + min=self.block_size, + step=self.block_size, + max=max(self.block_size, + self.max_num_seqs * max_decode_seq // self.block_size)) + self.graphed_buckets: Set[Any] = set() + + msg = ("Prompt bucket config (min, step, max_warmup) " + f"bs:{self.bucketing_global_state.prompt_bs_bucket_cfg}, " + f"seq:{self.bucketing_global_state.prompt_seq_bucket_cfg}") + logger.info(msg) + + msg = ("Decode bucket config (min, step, max_warmup) " + f"bs:{self.bucketing_global_state.decode_bs_bucket_cfg}, " + f"block:{self.bucketing_global_state.decode_block_bucket_cfg}") + logger.info(msg) + + def _prepare_prompt( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + ) -> PreparePromptMetadata: + input_tokens: List[List[int]] = [] + input_positions: List[List[int]] = [] + slot_mapping: List[List[int]] = [] + lora_index_mapping: List[List[int]] = [] + lora_prompt_mapping: List[List[int]] = [] + lora_requests: Set[LoRARequest] = set() + + seq_lens: List[int] = [] + context_lens: List[int] = [] + query_lens: List[int] = [] + prefix_block_tables: List[List[int]] = [] + multi_modal_inputs_list: List[MultiModalInputs] = [] + + if len(seq_group_metadata_list) == 0: + return PreparePromptMetadata.empty() + + for seq_group_metadata in seq_group_metadata_list: + assert seq_group_metadata.is_prompt + seq_ids = list(seq_group_metadata.seq_data.keys()) + assert len(seq_ids) == 1 + seq_id = seq_ids[0] + + computed_block_nums = seq_group_metadata.computed_block_nums + if (self.scheduler_config is not None + and self.scheduler_config.chunked_prefill_enabled + and not (computed_block_nums is None + or computed_block_nums == [])): + raise RuntimeError( + "chunked prefill cannot be used with prefix caching " + "now.") + + token_chunk_size = seq_group_metadata.token_chunk_size + seq_data = seq_group_metadata.seq_data[seq_id] + context_len = seq_data.get_num_computed_tokens() + # We should use get_len here because in case of preemption + # it contains output tokens. + seq_len = min(seq_data.get_len(), context_len + token_chunk_size) + prompt_tokens = seq_data.get_token_ids()[context_len:seq_len] + seq_lens.append(seq_len) + + # NOTE: This only works for oooooooxxx style attention. + if computed_block_nums is not None and len( + computed_block_nums) > 0 and self.sliding_window is None: + # Prefix is not supported with sliding_window + context_len = len(computed_block_nums) * self.block_size + prompt_tokens = prompt_tokens[context_len:] + prefix_block_tables.append(computed_block_nums) + elif self.scheduler_config.chunked_prefill_enabled: + if seq_group_metadata.block_tables is not None: + # Prefill has chunked before. + block_table = seq_group_metadata.block_tables[seq_id] + prefix_block_tables.append(block_table) + else: + # The first prefill. + prefix_block_tables.append([]) + else: + prefix_block_tables.append([]) + # Right now, prefill start is always 0. However, this + # assumption can be changed once chunked prefill is introduced. + assert context_len == 0 + + # actual prompt lens + context_lens.append(context_len) + query_lens.append(seq_len - context_len) + input_tokens.append(prompt_tokens) + # NOTE(woosuk): Here we assume that the first token in the prompt + # is always the first token in the sequence. + input_positions.append(list(range(context_len, seq_len))) + + mm_data = seq_group_metadata.multi_modal_data + if mm_data: + mm_kwargs = self.multi_modal_input_mapper(mm_data) + multi_modal_inputs_list.append(mm_kwargs) + + if seq_group_metadata.block_tables is None: + # During memory profiling, the block tables are not initialized + # yet. In this case, we just use a dummy slot mapping. + slot_mapping.append([_PAD_SLOT_ID] * seq_len) + continue + + # Compute the slot mapping. + slot_mapping.append([]) + block_table = seq_group_metadata.block_tables[seq_id] + + # Mask the [0, start_idx) tokens of the prompt with _PAD_SLOT_ID, + # where start_idx is max(0, seq_len - sliding_window). + # For example, if the prompt len is 10, sliding window is 8, and + # block size is 4, the first two tokens are masked and the slot + # mapping will be [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1]. + start_idx = 0 + if self.sliding_window is not None: + assert context_len == 0, ( + "Prefix caching is currently not supported with " + "sliding window attention") + start_idx = max(0, seq_len - self.sliding_window) + for i in range(context_len, seq_len): + if i < start_idx: + slot_mapping[-1].append(_PAD_SLOT_ID) + continue + + block_number = block_table[i // self.block_size] + block_offset = i % self.block_size + slot = block_number * self.block_size + block_offset + slot_mapping[-1].append(slot) + + max_query_len = max(query_lens) + sum_query_len = sum(query_lens) + real_num_seqs = len(query_lens) + assert max_query_len > 0 + + max_prompt_len = max( + find_bucket(max(seq_lens), + self.bucketing_global_state.prompt_seq_bucket_cfg), + self.block_size) + + lora_ids: List[int] = [] + for seq_group_metadata, context_len in zip(seq_group_metadata_list, + context_lens): + lora_id = seq_group_metadata.lora_int_id + lora_ids.append(lora_id) + + if lora_id > 0: + lora_requests.add(seq_group_metadata.lora_request) + + lora_index_mapping += [lora_id] * (max_prompt_len - context_len) + lora_prompt_mapping.extend( + [lora_id] * + (max_prompt_len - context_len + if seq_group_metadata.sampling_params.prompt_logprobs else 1)) + + input_tokens = make_tensor_with_pad(input_tokens, + max_len=max_prompt_len, + pad=0, + dtype=torch.long, + device=self.device) + + input_positions = make_tensor_with_pad(input_positions, + max_len=max_prompt_len, + pad=0, + dtype=torch.long, + device=self.device) + + slot_mapping = make_tensor_with_pad(slot_mapping, + max_len=max_prompt_len, + pad=_PAD_SLOT_ID, + dtype=torch.long, + device=self.device) + + seq_lens_tensor = torch.tensor(seq_lens, + dtype=torch.long, + device=self.device) + + block_indices, block_offsets = precompute_indices_and_offsets( + self.block_size, slot_mapping, True) + attn_metadata = self.attn_backend.make_metadata( + is_prompt=True, + block_list=None, + block_mapping=None, + block_usage=None, + block_indices=block_indices, + block_offsets=block_offsets, + block_scales=None, + attn_bias=None, + seq_lens_tensor=seq_lens_tensor, + num_prefills=real_num_seqs, + num_prefill_tokens=sum_query_len, + num_decode_tokens=0, + slot_mapping=slot_mapping, + ) + multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list) + + return PreparePromptMetadata(input_tokens=input_tokens, + input_positions=input_positions, + attn_metadata=attn_metadata, + seq_lens=seq_lens, + query_lens=query_lens, + lora_index_mapping=lora_index_mapping, + lora_prompt_mapping=lora_prompt_mapping, + lora_requests=lora_requests, + multi_modal_kwargs=multi_modal_kwargs, + slot_mapping=slot_mapping, + lora_ids=lora_ids) + + def _prepare_decode( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + ) -> PrepareDecodeMetadata: + input_tokens: List[List[int]] = [] + input_positions: List[List[int]] = [] + slot_mapping: List[List[int]] = [] + seq_lens: List[int] = [] + block_tables: List[List[int]] = [] + lora_index_mapping: List[List[int]] = [] + lora_prompt_mapping: List[List[int]] = [] + lora_requests: Set[LoRARequest] = set() + + if len(seq_group_metadata_list) == 0: + return PrepareDecodeMetadata.empty() + lora_ids: List[int] = [] + + dummy_slots = itertools.cycle( + range(_PAD_SLOT_ID, _PAD_SLOT_ID + self.block_size)) + + for seq_group_metadata in seq_group_metadata_list: + assert not seq_group_metadata.is_prompt + assert seq_group_metadata.token_chunk_size == 1 + + seq_ids = list(seq_group_metadata.seq_data.keys()) + lora_id = seq_group_metadata.lora_int_id + lora_ids.append(lora_id) + + if lora_id > 0: + lora_requests.add(seq_group_metadata.lora_request) + + for seq_id in seq_ids: + seq_data = seq_group_metadata.seq_data[seq_id] + generation_token = seq_data.get_last_token_id() + input_tokens.append([generation_token]) + + seq_len = seq_data.get_len() + position = seq_len - 1 + input_positions.append([position]) + + seq_len = seq_len if self.sliding_window is None else min( + seq_len, self.sliding_window) + seq_lens.append(seq_len) + + block_table = seq_group_metadata.block_tables[seq_id] + if len(block_table) == 0: + block_number = _PAD_BLOCK_ID + else: + block_number = block_table[position // self.block_size] + if block_number == _PAD_BLOCK_ID: + slot = next(dummy_slots) + else: + block_offset = position % self.block_size + slot = block_number * self.block_size + block_offset + slot_mapping.append([slot]) + lora_index_mapping.append(lora_id) + lora_prompt_mapping.append(lora_id) + + if self.sliding_window is not None: + sliding_window_blocks = (self.sliding_window // + self.block_size) + block_table = block_table[-sliding_window_blocks:] + block_tables.append(block_table) + + input_tokens = torch.tensor(input_tokens, + dtype=torch.long, + device=self.device) + input_positions = torch.tensor(input_positions, + dtype=torch.long, + device=self.device) + + num_decode_tokens = sum(seq_lens) + + blocks_used = [len(bt) for bt in block_tables if bt] + block_list = [] + block_scales = [] + for i, bt in enumerate(block_tables): + block_list.extend(bt) + blocks_in_group = len(bt) + if blocks_in_group > 0: + scale = 1.0 / blocks_in_group + block_scales.extend([scale] * blocks_in_group) + + block_mapping_nested: List[List[int]] = [ + [i] * b_u for i, b_u in enumerate(blocks_used) + ] + block_mapping: List[int] = list( + itertools.chain.from_iterable(block_mapping_nested)) + + last_block = [ + sl % self.block_size + 1 for sl in itertools.chain(*slot_mapping) + ] + block_usage = [[self.block_size] * (b_u - 1) + [lb] + for b_u, lb in zip(blocks_used, last_block)] + block_usage = list(itertools.chain(*block_usage)) + + block_bucket_size = find_bucket( + len(block_list), + self.bucketing_global_state.decode_block_bucket_cfg) + block_list = pad_list(block_list, block_bucket_size, _PAD_BLOCK_ID) + block_mapping = pad_list(block_mapping, block_bucket_size, -1) + block_usage = pad_list(block_usage, block_bucket_size, 1) + block_scales = pad_list(block_scales, block_bucket_size, 0.0) + + block_list = torch.tensor(block_list, + dtype=torch.int, + device=self.device) + block_mapping = torch.tensor(block_mapping, + dtype=torch.long, + device=self.device) + block_usage = torch.tensor(block_usage, + dtype=self.model_config.dtype, + device=self.device) + + slot_mapping = torch.tensor(slot_mapping, + dtype=torch.long, + device=self.device) + + block_indices, block_offsets = precompute_indices_and_offsets( + self.block_size, slot_mapping, False) + block_scales = torch.tensor(block_scales, + dtype=self.model_config.dtype, + device=self.device) + + attn_metadata = self.attn_backend.make_metadata( + is_prompt=False, + block_list=block_list, + block_mapping=block_mapping, + block_usage=block_usage, + block_indices=block_indices, + block_offsets=block_offsets, + block_scales=block_scales, + attn_bias=None, + seq_lens_tensor=None, + num_prefills=0, + num_prefill_tokens=0, + num_decode_tokens=num_decode_tokens, + slot_mapping=slot_mapping, + ) + return PrepareDecodeMetadata(input_tokens=input_tokens, + input_positions=input_positions, + attn_metadata=attn_metadata, + lora_index_mapping=lora_index_mapping, + lora_prompt_mapping=lora_prompt_mapping, + lora_requests=lora_requests, + slot_mapping=slot_mapping, + lora_ids=lora_ids) + + def prepare_input_tensors( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + ) -> Tuple[TModelInputForHPU, SamplingMetadata]: + if len(seq_group_metadata_list) == 0: + return self._model_input_cls(), None + + input_tokens = None + input_positions = None + lora_mapping = None + lora_requests = None + multi_modal_kwargs = None + batch_type = None + seq_lens = None + query_lens = None + real_batch_size = None + batch_size_padded = None + + self.event_start = self.profiler.get_timestamp_us() + is_prompt = seq_group_metadata_list[0].is_prompt + base_event_name = 'prompt' if is_prompt else 'decode' + self.profiler.start('internal', base_event_name) + + real_batch_size = len(seq_group_metadata_list) + bucket_cfg = self.bucketing_global_state.prompt_bs_bucket_cfg \ + if is_prompt else self.bucketing_global_state.decode_bs_bucket_cfg + batch_size_padded = find_bucket(real_batch_size, bucket_cfg) + batch_size_padding = batch_size_padded - real_batch_size + seq_group_metadata_list = seq_group_metadata_list.copy() + if batch_size_padding > 0: + dummy_seq_group_metadata = self.create_dummy_seq_group_metadata( + 0, 0, is_prompt) + seq_group_metadata_list.extend(dummy_seq_group_metadata + for _ in range(batch_size_padding)) + + prefill_reqs = [] + decode_reqs = [] + for seq_group_meta in seq_group_metadata_list: + if seq_group_meta.is_prompt: + prefill_reqs.append(seq_group_meta) + else: + decode_reqs.append(seq_group_meta) + + # Prepare input tensors. + ( + input_tokens, + input_positions, + prefill_attn_metadata, + seq_lens, + query_lens, + lora_index_mapping, + lora_prompt_mapping, + lora_requests, + multi_modal_kwargs, + slot_mapping, + lora_ids, + ) = self._prepare_prompt(prefill_reqs) + ( + decode_input_tokens, + decode_input_positions, + decode_attn_metadata, + decode_lora_index_mapping, + decode_lora_prompt_mapping, + decode_lora_requests, + decode_slot_mapping, + decode_lora_ids, + ) = self._prepare_decode(decode_reqs) + sampling_metadata = SamplingMetadata.prepare(seq_group_metadata_list, + seq_lens, query_lens, + self.device, + self.pin_memory) + + if not self.scheduler_config.chunked_prefill_enabled: + assert (len(prefill_reqs) and len(decode_reqs)) == 0 + + num_prefills = len(seq_lens) + num_prefill_tokens = len(input_tokens) + num_decode_tokens = len(decode_input_tokens) + + # NOTE(kzawora): Here we diverge from GPU code - we don't + # support mixed batches, so we either use decode or prefill + # inputs, without coalescing. + assert (num_prefills == 0 and num_decode_tokens > 0) or ( + num_prefills > 0 + and num_decode_tokens == 0), "HPU does not support mixed batches!" + if num_decode_tokens > 0: + input_tokens = decode_input_tokens + input_positions = decode_input_positions + slot_mapping = decode_slot_mapping + lora_index_mapping = decode_lora_index_mapping + lora_prompt_mapping = decode_lora_prompt_mapping + lora_requests = decode_lora_requests + lora_ids = decode_lora_ids + + # FIXME: We need to adjust selected_token_indices to accommodate + # for padding + max_len = input_tokens.size(1) + paddings = [max_len - s for s in seq_lens] + paddings = [0] + paddings[:-1] + paddings = list(itertools.accumulate(paddings)) + paddings_prompt_logprobs = [] + for i, seq_group_metadata in enumerate(seq_group_metadata_list): + if seq_group_metadata.sampling_params.prompt_logprobs is not None \ + and seq_group_metadata.is_prompt: + paddings_prompt_logprobs += ([paddings[i]] * seq_lens[i]) + paddings = torch.tensor( + paddings_prompt_logprobs if paddings_prompt_logprobs else paddings, + dtype=sampling_metadata.selected_token_indices.dtype, + device=sampling_metadata.selected_token_indices.device) + sampling_metadata.selected_token_indices.add_(paddings) + + if self.lora_config: + lora_mapping = LoRAMapping( + **dict(index_mapping=lora_index_mapping, + prompt_mapping=lora_prompt_mapping, + is_prefill=(num_prefills > 0))) + else: + lora_mapping = None + + if (prefill_attn_metadata is not None + and decode_attn_metadata is not None): + batch_type = BatchType.MIXED + raise NotImplementedError("Mixed batch is not supported on HPU") + elif prefill_attn_metadata is not None: + batch_type = BatchType.PREFILL + else: + batch_type = BatchType.DECODE + + metadata_dict = { + "input_tokens": input_tokens, + "input_positions": input_positions, + "selected_token_indices": sampling_metadata.selected_token_indices, + "lora_requests": lora_requests, + "lora_mapping": lora_mapping, + "multi_modal_kwargs": multi_modal_kwargs, + "num_prefill_tokens": num_prefill_tokens, + "num_decode_tokens": num_decode_tokens, + "slot_mapping": slot_mapping, + "num_prefills": num_prefills, + "batch_type": batch_type, + "seq_lens": seq_lens, + "query_lens": query_lens + } + if prefill_attn_metadata is not None: + metadata_dict.update(prefill_attn_metadata.asdict_zerocopy()) + else: + assert decode_attn_metadata is not None + metadata_dict.update(decode_attn_metadata.asdict_zerocopy()) + + attn_metadata = prefill_attn_metadata if \ + prefill_attn_metadata is not None else decode_attn_metadata + + return self._model_input_cls(input_tokens=input_tokens, + seq_lens=seq_lens, + query_lens=query_lens, + input_positions=input_positions, + attn_metadata=attn_metadata, + lora_requests=lora_requests, + lora_mapping=lora_mapping, + multi_modal_kwargs=multi_modal_kwargs, + real_batch_size=real_batch_size, + batch_size_padded=batch_size_padded, + lora_ids=lora_ids), \ + sampling_metadata + + def _seq_len(self, attn_metadata): + if attn_metadata.num_prefills != 0: + return attn_metadata.slot_mapping.size(1) + else: + return attn_metadata.block_list.numel() + + def trim_attn_metadata(self, metadata: AttentionMetadata) -> object: + # NOTE(kzawora): To anyone working on this in the future: + # Trimming metadata is required when using HPUGraphs. + # Attention metadata is going to be hashed by PT bridge, and + # appropriate HPUGraphs will be matched based on all inputs' hash. + + # Before you put more keys in here, make sure you know their + # value type and make sure you know how it's going to be hashed. + # You can find that information in input_hash function + # in habana_frameworks/torch/hpu/graphs.py. You can also hash + # it manually with torch.hpu.graphs.input_hash(attention_metadata) + + # If you use primitive types here - they will get hashed based + # on their value. You *will* get lots of excessive graph captures + # (and an OOM eventually) if you decide to put something like + # seq_len int here. + # If you absolutely need a scalar, put it in a tensor. Tensors + # get hashed using their metadata, not their values: + # input_hash(torch.tensor(123)) == input_hash(torch.tensor(321)) + # input_hash(123) != input_hash(321) + # input_hash("abc") != input_hash("cba") + attention_metadata = subtuple(metadata, 'TrimmedAttentionMetadata', [ + 'attn_bias', 'seq_lens_tensor', 'block_list', 'block_mapping', + 'block_usage', 'slot_mapping', 'is_prompt', 'block_indices', + 'block_offsets', 'block_scales' + ]) + return attention_metadata + + def create_dummy_seq_group_metadata(self, + group_id, + seq_len, + is_prompt, + lora_request=None): + sampling_params = SamplingParams(temperature=0) + num_blocks = math.ceil(seq_len / self.block_size) + seq_len = max(seq_len, 1) + if is_prompt: + input_len = seq_len + output_len = 0 + block_tables = None + else: + input_len = seq_len - 1 + output_len = 1 + block_tables = {group_id: [_PAD_BLOCK_ID] * num_blocks} + prompt_token_ids = [0] * input_len + output_token_ids = [1] * output_len + prompt_token_ids_array = array('l', prompt_token_ids) # noqa: F821 + seq_data = SequenceData(prompt_token_ids_array) + seq_data.output_token_ids = output_token_ids + return SequenceGroupMetadata(request_id=str(group_id), + is_prompt=(output_len == 0), + seq_data={group_id: seq_data}, + sampling_params=sampling_params, + block_tables=block_tables, + lora_request=lora_request) + + def profile_run(self) -> None: + num_layers = self.model_config.get_num_layers(self.parallel_config) + kv_caches = [None] * num_layers + max_batch_size = self.bucketing_global_state.prompt_bs_bucket_cfg[-1] + max_seq_len = min( + self.bucketing_global_state.prompt_seq_bucket_cfg[-1], + self.max_num_batched_tokens // max_batch_size) + + self.warmup_scenario(max_batch_size, max_seq_len, True, kv_caches, + False, True) + return + + def warmup_scenario(self, + batch_size, + seq_len, + is_prompt, + kv_caches, + is_pt_profiler_run=False, + is_lora_profile_run=False) -> None: + use_graphs = self._use_graphs(batch_size, seq_len, is_prompt) + scenario_name = ("warmup_" + f"{'prompt' if is_prompt else 'decode'}_" + f"bs{batch_size}_" + f"seq{seq_len}_" + f"graphs{'T' if use_graphs else 'F'}") + max_num_seqs = self.scheduler_config.max_num_seqs + # This represents the maximum number of different requests + # that will have unique loras, an therefore the max amount of memory + # consumption create dummy lora request copies from the lora request + # passed in, which contains a lora from the lora warmup path. + dummy_lora_requests: List[LoRARequest] = [] + dummy_lora_requests_per_seq: List[LoRARequest] = [] + if self.lora_config and is_lora_profile_run: + assert self.lora_manager is not None + with self.lora_manager.dummy_lora_cache(): + for idx in range(self.lora_config.max_loras): + lora_id = idx + 1 + dummy_lora_request = LoRARequest( + lora_name=f"warmup_{lora_id}", + lora_int_id=lora_id, + lora_local_path="/not/a/real/path", + ) + self.lora_manager.add_dummy_lora(dummy_lora_request, + rank=LORA_WARMUP_RANK) + dummy_lora_requests.append(dummy_lora_request) + dummy_lora_requests_per_seq = [ + dummy_lora_requests[idx % len(dummy_lora_requests)] + for idx in range(max_num_seqs) + ] + self.profiler.start('internal', scenario_name) + times = 3 if use_graphs or is_pt_profiler_run else 1 + if self.lora_config and not is_lora_profile_run: + lora_mapping = LoRAMapping( + **dict(index_mapping=[0] * batch_size * seq_len, + prompt_mapping=[0] * batch_size * seq_len, + is_prefill=is_prompt)) + self.set_active_loras(set(), lora_mapping) + if is_prompt: + seqs = [ + self.create_dummy_seq_group_metadata( + i, + seq_len, + is_prompt, + lora_request=dummy_lora_requests_per_seq[i] + if dummy_lora_requests_per_seq else None) + for i in range(batch_size) + ] + else: + # FIXME: seq_len is actually number of blocks + blocks = [seq_len // batch_size for _ in range(batch_size)] + blocks[0] += seq_len % batch_size + seqs = [ + self.create_dummy_seq_group_metadata( + i, + b * self.block_size - 1, + is_prompt, + lora_request=dummy_lora_requests_per_seq[i] + if dummy_lora_requests_per_seq else None) + for i, b in enumerate(blocks) + ] + torch.hpu.synchronize() + profiler = None + if is_pt_profiler_run and self.is_driver_worker: + profiler = setup_profiler() + profiler.start() + for _ in range(times): + inputs = self.prepare_model_input(seqs) + self.execute_model(inputs, kv_caches, warmup_mode=True) + torch.hpu.synchronize() + if profiler: + profiler.step() + if profiler: + profiler.stop() + self.profiler.end() + gc.collect() + + def remove_all_loras(self): + if not self.lora_manager: + raise RuntimeError("LoRA is not enabled.") + self.lora_manager.remove_all_adapters() + + def set_active_loras(self, lora_requests: Set[LoRARequest], + lora_mapping: LoRAMapping) -> None: + if not self.lora_manager: + raise RuntimeError("LoRA is not enabled.") + self.lora_manager.set_active_adapters(lora_requests, lora_mapping) + + def add_lora(self, lora_request: LoRARequest) -> bool: + if not self.lora_manager: + raise RuntimeError("LoRA is not enabled.") + return self.lora_manager.add_adapter(lora_request) + + def remove_lora(self, lora_id: int) -> bool: + if not self.lora_manager: + raise RuntimeError("LoRA is not enabled.") + return self.lora_manager.remove_adapter(lora_id) + + def pin_lora(self, lora_id: int) -> bool: + if not self.lora_manager: + raise RuntimeError("LoRA is not enabled.") + return self.lora_manager.pin_adapter(lora_id) + + def list_loras(self) -> Set[int]: + if not self.lora_manager: + raise RuntimeError("LoRA is not enabled.") + return self.lora_manager.list_adapters() + + def log_warmup(self, phase, i, max_i, batch_size, seq_len): + free_mem = format_bytes( + HabanaMemoryProfiler.current_free_device_memory()) + dim = "num_blocks" + if phase == "Prompt": + dim = "seq_len" + msg = (f"[Warmup][{phase}][{i+1}/{max_i}] " + f"batch_size:{batch_size} " + f"{dim}:{seq_len} " + f"free_mem:{free_mem}") + logger.info(msg) + + def warmup_all_buckets(self, buckets, is_prompt, kv_caches): + for i, (batch_size, seq_len) in enumerate(reversed(buckets)): + self.log_warmup('Prompt' if is_prompt else 'Decode', i, + len(buckets), batch_size, seq_len) + self.warmup_scenario(batch_size, seq_len, is_prompt, kv_caches) + + def warmup_graphs(self, + strategy, + buckets, + is_prompt, + kv_caches, + available_mem, + starting_mem=0, + total_batch_seq=0.001): + total_mem = starting_mem + idx = 0 + phase = f'Graph/{"Prompt" if is_prompt else "Decode"}' + num_candidates = len(buckets) + ordering : Union[Callable[[Any], Tuple[Any, Any]], \ + Callable[[Any], Tuple[Any, Any, Any]]] + if strategy == 'min_tokens': + ordering = lambda b: (b[0] * b[1], b[1], b[0]) + elif strategy == 'max_bs': + ordering = lambda b: (-b[0], b[1]) + else: + raise NotImplementedError( + f'Unsupported graph allocation strategy: {strategy}') + buckets = list(sorted(buckets, key=ordering)) + captured_all = True + for idx, (batch_size, seq_len) in enumerate(buckets): + # Graph memory usage is proportional to seq dimension in a batch + batch_seq = batch_size * seq_len if is_prompt else batch_size + mem_estimate = batch_seq / total_batch_seq * total_mem + if mem_estimate >= available_mem: + captured_all = False + continue + graphed_bucket = (batch_size, seq_len, is_prompt) + if graphed_bucket in self.graphed_buckets: + continue + self.graphed_buckets.add(graphed_bucket) + self.log_warmup(phase, idx, num_candidates, batch_size, seq_len) + with HabanaMemoryProfiler() as mem_prof: + self.warmup_scenario(batch_size, seq_len, is_prompt, kv_caches) + used_mem = align_workers(mem_prof.consumed_device_memory, + torch.distributed.ReduceOp.MAX) + available_mem -= used_mem + total_mem += used_mem + total_batch_seq += batch_seq + + return total_mem, total_batch_seq, captured_all + + def log_graph_warmup_summary(self, buckets, is_prompt, total_mem): + num_candidates = len(buckets) + phase = f'Graph/{"Prompt" if is_prompt else "Decode"}' + graphed = list(c[:2] for c in self.graphed_buckets + if c[2] == is_prompt) + if num_candidates == 0: + num_candidates = 1 + msg = (f'{phase} captured:{len(graphed)} ' + f'({100 * len(graphed) / num_candidates:.1f}%) ' + f'used_mem:{format_bytes(total_mem)} ' + f'buckets:{sorted(list(graphed))}') + logger.info(msg) + + @torch.inference_mode() + def warmup_model(self, kv_caches: List[torch.Tensor]) -> None: + if profile := os.environ.get('VLLM_PT_PROFILE', None): + phase, bs, seq_len, graph = profile.split('_') + is_prompt = phase == 'prompt' + graphs = graph == 't' + if graphs: + self.graphed_buckets.add((int(bs), int(seq_len), is_prompt)) + self.warmup_scenario(int(bs), int(seq_len), is_prompt, kv_caches, + True) + raise AssertionError("Finished profiling") + if self.skip_warmup: + logger.info("Skipping warmup...") + return + self.profiler.start('internal', 'warmup') + max_blocks = kv_caches[0][0].size(0) + + self.bucketing_global_state.prompt_buckets, prompt_omitted_buckets = \ + generate_prompt_buckets( + self.bucketing_global_state.prompt_bs_bucket_cfg, + self.bucketing_global_state.prompt_seq_bucket_cfg, + self.max_num_batched_tokens) + + msg = (f"Generated {len(self.bucketing_global_state.prompt_buckets)} " + f"prompt buckets [bs, seq]: \ + {list(sorted(self.bucketing_global_state.prompt_buckets))}") + logger.info(msg) + + msg = (f"Omitted {len(prompt_omitted_buckets)} " + "prompt buckets due to exceeded token budget " + f"(max_num_batched_tokens={self.max_num_batched_tokens})") + logger.info(msg) + + msg = f"Omitted prompt buckets: {list(sorted(prompt_omitted_buckets))}" + logger.debug(msg) + + self.bucketing_global_state.decode_buckets = generate_decode_buckets( + self.bucketing_global_state.decode_bs_bucket_cfg, + self.bucketing_global_state.decode_block_bucket_cfg, max_blocks) + logger.info("Generated %d decode buckets [bs, total_blocks]: %s", + len(self.bucketing_global_state.decode_buckets), + list(sorted(self.bucketing_global_state.decode_buckets))) + + if not htorch.utils.internal.is_lazy() and not self.enforce_eager: + cache_size_limit = len( + self.bucketing_global_state.prompt_buckets) + len( + self.bucketing_global_state.decode_buckets) + 1 + torch._dynamo.config.cache_size_limit = max( + cache_size_limit, torch._dynamo.config.cache_size_limit) + # Multiply by 8 to follow the original default ratio between + # the cache_size_limit and accumulated_cache_size_limit + torch._dynamo.config.accumulated_cache_size_limit = max( + cache_size_limit * 8, + torch._dynamo.config.accumulated_cache_size_limit) + + start_mem = HabanaMemoryProfiler.current_device_memory_usage() + start_time = time.perf_counter() + + compile_only_mode_context = functools.partial(bc.env_setting, + "PT_COMPILE_ONLY_MODE", + True) + can_use_compile_only_mode = True + try: + with compile_only_mode_context(): + pass + logger.debug("Using PT_COMPILE_ONLY_MODE.") + except KeyError: + can_use_compile_only_mode = False + logger.warning('Cannot use PT_COMPILE_ONLY_MODE. ' + 'Warmup time will be negatively impacted. ' + 'Please update Gaudi Software Suite.') + with compile_only_mode_context( + ) if can_use_compile_only_mode else contextlib.nullcontext(): + self.warmup_all_buckets(self.bucketing_global_state.prompt_buckets, + True, kv_caches) + self.warmup_all_buckets(self.bucketing_global_state.decode_buckets, + False, kv_caches) + + if not self.enforce_eager and htorch.utils.internal.is_lazy(): + assert self.mem_margin is not None, \ + ("HabanaWorker.determine_num_available_blocks needs " + "to be called before warming up the model.") + free_mem = HabanaMemoryProfiler.current_free_device_memory() + graph_free_mem = free_mem - self.mem_margin + graph_free_mem = align_workers(graph_free_mem, + torch.distributed.ReduceOp.MIN) + prompt_graph_mem_ratio = float( + os.environ.get('VLLM_GRAPH_PROMPT_RATIO', '0.3')) + prompt_available_memory = (prompt_graph_mem_ratio * + graph_free_mem) + decode_available_memory = (graph_free_mem - + prompt_available_memory) + msg = ( + f"Using {format_bytes(graph_free_mem)}" + f"/{format_bytes(free_mem)} " + "of free device memory for HPUGraphs, " + f"{format_bytes(prompt_available_memory)} for prompt and " + f"{format_bytes(decode_available_memory)} for decode " + f"(VLLM_GRAPH_PROMPT_RATIO={prompt_graph_mem_ratio})") + logger.info(msg) + prompt_strategy = os.environ.get('VLLM_GRAPH_PROMPT_STRATEGY', + 'min_tokens') + decode_strategy = os.environ.get('VLLM_GRAPH_DECODE_STRATEGY', + 'max_bs') + mem_post_prompt, prompt_batch_seq, prompt_captured_all = \ + self.warmup_graphs( + prompt_strategy, self.bucketing_global_state.prompt_buckets, + True, kv_caches, prompt_available_memory) + mem_post_decode, decode_batch_seq, decode_captured_all = \ + self.warmup_graphs( + decode_strategy, self.bucketing_global_state.decode_buckets, + False, kv_caches, decode_available_memory) + + # Not all prompt buckets were captured, but all decode buckets + # were captured and we have some free graph-allocated space + # left. Let's try to use it for capturing more prompt buckets. + if (mem_post_decode + mem_post_prompt < graph_free_mem + and not prompt_captured_all and decode_captured_all): + mem_post_prompt, _, prompt_captured_all = ( + self.warmup_graphs( + prompt_strategy, + self.bucketing_global_state.prompt_buckets, True, + kv_caches, + graph_free_mem - mem_post_prompt - mem_post_decode, + mem_post_prompt, prompt_batch_seq)) + + # Not all decode buckets were captured, but all prompt buckets + # were captured and we have some free graph-allocated space + # left. Let's try to use it for capturing more decode buckets. + if mem_post_decode + mem_post_prompt < graph_free_mem \ + and not decode_captured_all \ + and prompt_captured_all: + mem_post_decode, _, _ = self.warmup_graphs( + decode_strategy, + self.bucketing_global_state.decode_buckets, False, + kv_caches, + graph_free_mem - mem_post_prompt - mem_post_decode, + mem_post_decode, decode_batch_seq) + + self.log_graph_warmup_summary( + self.bucketing_global_state.prompt_buckets, True, + mem_post_prompt) + self.log_graph_warmup_summary( + self.bucketing_global_state.decode_buckets, False, + mem_post_decode) + + end_time = time.perf_counter() + end_mem = HabanaMemoryProfiler.current_device_memory_usage() + elapsed_time = end_time - start_time + msg = ( + f"Warmup finished in {elapsed_time:.0f} secs, " + f"allocated {format_bytes(end_mem - start_mem)} of device memory") + logger.info(msg) + self.profiler.end() + + @property + def vocab_size(self) -> int: + return self.model_config.get_vocab_size() + + @property + def mem_margin(self) -> Optional[int]: + return self._mem_margin + + @mem_margin.setter + def mem_margin(self, value): + self._mem_margin = value + + +def _maybe_wrap_in_hpu_graph(*args, **kwargs): + return htorch.hpu.wrap_in_hpu_graph( + HpuModelAdapter(*args, **kwargs), disable_tensor_cache=True + ) if htorch.utils.internal.is_lazy() else HpuModelAdapter(*args, **kwargs) + + +class HabanaProfilerCounterHelper(): + + def __init__(self): + self.niter = 0 + self.average_real_throughput = None + self.logged_once = False + self.real_seq_lens = [] + self.prompt_seq_lens = [] + + def capture_seq_group_metadata_stats(self, seq_group_metadata_list): + self.real_seq_lens = [ + len(seq_data.prompt_token_ids) + len(seq_data.output_token_ids) + for seq_group_metadata in seq_group_metadata_list + for seq_data in seq_group_metadata.seq_data.values() + ] + self.prompt_seq_lens = [ + len(seq_data.prompt_token_ids) + for seq_group_metadata in seq_group_metadata_list + for seq_data in seq_group_metadata.seq_data.values() + ] + + def get_counter_dict(self, cache_config, duration, seq_len, + batch_size_padded, real_batch_size, is_prompt): + throughput = batch_size_padded / (duration / 1e6) + throughput_effective = real_batch_size / (duration / 1e6) + + real_max_seq_len = max(self.real_seq_lens) + real_num_tokens = sum(self.real_seq_lens) + padded_num_tokens = batch_size_padded * seq_len + batch_token_utilization = real_num_tokens / padded_num_tokens + if self.average_real_throughput is None: + self.average_real_throughput = throughput_effective + else: # https://www.heikohoffmann.de/htmlthesis/node134.html + self.average_real_throughput = self.average_real_throughput + 1 / ( + self.niter + 1) * (throughput_effective - + self.average_real_throughput) + phase = "prompt" if is_prompt else "decode" + counters = { + f'{phase}_bucket_batch_size': batch_size_padded, + f'{phase}_batch_size': real_batch_size, + f'{phase}_bucket_seq_len': seq_len, + f'{phase}_seq_len': real_max_seq_len, + f'{phase}_bucket_gen_throughput': throughput, + f'{phase}_real_gen_throughput': throughput_effective, + f'{phase}_batch_token_utilization': batch_token_utilization, + 'average_real_throughput': self.average_real_throughput, + 'engine_iteration': self.niter, + } + self.niter += 1 + if is_prompt: + prompt_bucket_in_throughput = (seq_len * batch_size_padded) / ( + duration / 1e6) + prompt_real_in_throughput = sum( + self.prompt_seq_lens) / (duration / 1e6) + counters[ + f'{phase}_bucket_in_throughput'] = prompt_bucket_in_throughput + counters[f'{phase}_real_in_throughput'] = prompt_real_in_throughput + + # KV cache might not be created yet (e.g. for profiling run) + if cache_config.num_gpu_blocks is not None and \ + cache_config.num_gpu_blocks != 0: + cache_num_blocks_used = [ + math.ceil(sl / cache_config.block_size) + for sl in self.real_seq_lens + ] + cache_total_num_blocks_used = sum(cache_num_blocks_used) + num_cache_blocks = cache_config.num_gpu_blocks + cache_total_num_free_blocks = \ + num_cache_blocks - cache_total_num_blocks_used + cache_computed_utilization = \ + cache_total_num_blocks_used / num_cache_blocks + max_blocks_per_seq = math.ceil(seq_len / cache_config.block_size) + batch_block_utilization = cache_total_num_blocks_used / ( + batch_size_padded * max_blocks_per_seq) + counters['cache_num_blocks_used'] = cache_total_num_blocks_used + counters['cache_num_free_blocks'] = cache_total_num_free_blocks + counters['cache_computed_utilization'] = cache_computed_utilization + counters[ + f'{phase}_batch_block_utilization'] = batch_block_utilization + if not self.logged_once: + counters['const_cache_num_blocks'] = cache_config.num_gpu_blocks + counters[ + 'const_gpu_memory_utilization'] = \ + cache_config.gpu_memory_utilization + counters['const_block_size'] = cache_config.block_size + self.logged_once = True + return counters + + +def unwrap_model(model): + if isinstance(model, torch._dynamo.eval_frame.OptimizedModule): + return unwrap_model(model._orig_mod) + else: + model = list(vars(model)['_modules'].values())[0] + modules = list(vars(model)['_modules'].values()) + return modules + + +class HPUModelRunner(HPUModelRunnerBase[ModelInputForHPUWithSamplingMetadata]): + """ + GPU model runner with sampling step. + """ + _model_input_cls: Type[ModelInputForHPUWithSamplingMetadata] = ( + ModelInputForHPUWithSamplingMetadata) + + def make_model_input_from_broadcasted_tensor_dict( + self, + tensor_dict: Dict[str, Any], + ) -> ModelInputForHPUWithSamplingMetadata: + return ( + ModelInputForHPUWithSamplingMetadata.from_broadcasted_tensor_dict( + tensor_dict, + attn_backend=self.attn_backend, + )) + + @torch.inference_mode() + def prepare_model_input( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + virtual_engine: int = 0, + finished_requests_ids: Optional[List[str]] = None + ) -> ModelInputForHPUWithSamplingMetadata: + """Prepare the model input based on a given sequence group, including + metadata for the sampling step. + The API assumes seq_group_metadata_list is sorted by prefill -> decode. + The result tensors and data structure also batches input in prefill + -> decode order. For example, + - input_tokens[:num_prefill_tokens] contains prefill tokens. + - input_tokens[num_prefill_tokens:] contains decode tokens. + If cuda graph is required, this API automatically pads inputs. + """ + with self.profiler.record_event('internal', 'prepare_input_tensors'): + assert seq_group_metadata_list is not None + if self.profiler.enabled: + self.profiler_counter_helper.capture_seq_group_metadata_stats( + seq_group_metadata_list=seq_group_metadata_list) + model_input, sampling_metadata = self.prepare_input_tensors( + seq_group_metadata_list) + assert model_input.attn_metadata is not None + is_prompt = model_input.attn_metadata.is_prompt + + return dataclasses.replace(model_input, + sampling_metadata=sampling_metadata, + is_prompt=is_prompt, + virtual_engine=virtual_engine) + + def finish_measurements(self): + from neural_compressor.torch.quantization import finalize_calibration + finalize_calibration(self.model.model) + + def _check_config(self, batch_size, seq_len, is_prompt, warmup_mode): + cfg = (batch_size, seq_len, is_prompt) + seen = cfg in self.seen_configs + self.seen_configs.add(cfg) + if not seen and not warmup_mode: + phase = 'prompt' if is_prompt else 'decode' + logger.warning("Configuration: (%s, %s, %s) was not warmed-up!", + phase, batch_size, seq_len) + + def create_lora_mask(self, input_tokens: torch.Tensor, lora_ids: List[int], + is_prompt: bool): + ''' + This is a helper function to create the mask for lora computations. + Lora Mask is needed to ensure we match the correct lora weights for the + for the request. + For Prompt phase we have + lora_mask with shape (batch_size * seq_len, max_loras * max_rank) + lora_logits_mask with shape (batch_size, max_loras * max_rank) + For Decode phase we have both + lora_mask and lora_logits_mask with shape + (batch_size, max_loras * max_rank) + ''' + lora_mask: torch.Tensor = None + lora_logits_mask: torch.Tensor = None + lora_index = 0 + + if self.lora_config: + if is_prompt: + lora_mask = torch.zeros( + input_tokens.shape[0] * input_tokens.shape[1], + (self.lora_config.max_loras) *\ + self.lora_config.max_lora_rank, + dtype=self.lora_config.lora_dtype) + lora_logits_mask = torch.zeros( + input_tokens.shape[0], (self.lora_config.max_loras) * + self.lora_config.max_lora_rank, + dtype=self.lora_config.lora_dtype) + + ones = torch.ones(input_tokens.shape[1], + self.lora_config.max_lora_rank, + dtype=self.lora_config.lora_dtype) + logit_ones = torch.ones(1, + self.lora_config.max_lora_rank, + dtype=self.lora_config.lora_dtype) + + for i in range(len(lora_ids)): + if lora_ids[i] == 0: + continue + lora_index = self.lora_manager._adapter_manager.\ + lora_index_to_id.index(lora_ids[i]) + start_row = i * input_tokens.shape[1] + end_row = start_row + input_tokens.shape[1] + start_col = lora_index * self.lora_config.max_lora_rank + end_col = start_col + self.lora_config.max_lora_rank + lora_mask[start_row:end_row, start_col:end_col] = ones + lora_logits_mask[i, start_col:end_col] = logit_ones + lora_mask = lora_mask.to('hpu') + lora_logits_mask = lora_logits_mask.to('hpu') + else: + lora_mask = torch.zeros(input_tokens.shape[0], + (self.lora_config.max_loras) * + self.lora_config.max_lora_rank, + dtype=self.lora_config.lora_dtype) + ones = torch.ones(1, + self.lora_config.max_lora_rank, + dtype=self.lora_config.lora_dtype) + for i in range(len(lora_ids)): + if lora_ids[i] == 0: + continue + lora_index = self.lora_manager._adapter_manager.\ + lora_index_to_id.index(lora_ids[i]) + start_pos = lora_index * self.lora_config.max_lora_rank + end_pos = start_pos + self.lora_config.max_lora_rank + lora_mask[i, start_pos:end_pos] = ones + lora_mask = lora_mask.to('hpu') + lora_logits_mask = lora_mask + + return lora_mask, lora_logits_mask + + @torch.inference_mode() + def execute_model( + self, + model_input: ModelInputForHPUWithSamplingMetadata, + kv_caches: List[torch.Tensor], + intermediate_tensors: Optional[IntermediateTensors] = None, + num_steps: int = 1, + warmup_mode=False, + ) -> Optional[Union[List[SamplerOutput], IntermediateTensors]]: + if num_steps > 1: + raise ValueError( + "num_steps > 1 is not supported in HPUModelRunner") + + if self.lora_config: + assert model_input.lora_requests is not None + assert model_input.lora_mapping is not None + self.set_active_loras(model_input.lora_requests, + model_input.lora_mapping) + input_tokens = model_input.input_tokens + input_positions = model_input.input_positions + attn_metadata = model_input.attn_metadata + sampling_metadata = model_input.sampling_metadata + real_batch_size = model_input.real_batch_size + batch_size_padded = model_input.batch_size_padded + assert input_tokens is not None + assert input_positions is not None + assert sampling_metadata is not None + assert attn_metadata is not None + is_prompt = attn_metadata.is_prompt + assert is_prompt is not None + batch_size = input_tokens.size(0) + seq_len = self._seq_len(attn_metadata) + use_graphs = self._use_graphs(batch_size, seq_len, is_prompt) + self._check_config(batch_size, seq_len, is_prompt, warmup_mode) + + lora_mask: torch.Tensor = None + lora_logits_mask: torch.Tensor = None + if self.lora_config: + assert model_input.lora_ids is not None + lora_mask, lora_logits_mask = self.create_lora_mask( + input_tokens, model_input.lora_ids, attn_metadata.is_prompt) + + execute_model_kwargs = { + "input_ids": input_tokens, + "positions": input_positions, + "kv_caches": kv_caches, + "attn_metadata": self.trim_attn_metadata(attn_metadata), + "intermediate_tensors": intermediate_tensors, + "lora_mask": lora_mask, + **(model_input.multi_modal_kwargs or {}), + } + if htorch.utils.internal.is_lazy(): + execute_model_kwargs.update({"bypass_hpu_graphs": not use_graphs}) + + htorch.core.mark_step() + if self.is_driver_worker: + model_event_name = ("model_" + f"{'prompt' if is_prompt else 'decode'}_" + f"bs{batch_size}_" + f"seq{seq_len}_" + f"graphs{'T' if use_graphs else 'F'}") + else: + model_event_name = 'model_executable' + with self.profiler.record_event('internal', model_event_name): + hidden_states = self.model.forward( + **execute_model_kwargs, + selected_token_indices=sampling_metadata.selected_token_indices + ) + + if self.lora_config: + LoraMask.setLoraMask( + lora_logits_mask.index_select( + 0, sampling_metadata.selected_token_indices)) + + # Compute the logits. + with self.profiler.record_event( + 'internal', ('compute_logits_' + f'{"prompt" if is_prompt else "decode"}_bs' + f'{batch_size}_' + f'seq{seq_len}')): + sampling_metadata.selected_token_indices = None + logits = self.model.compute_logits(hidden_states, + sampling_metadata) + htorch.core.mark_step() + # Only perform sampling in the driver worker. + if not self.is_driver_worker: + return [] + + if model_input.async_callback is not None: + model_input.async_callback() + + # Sample the next token. + with self.profiler.record_event( + 'internal', ('sample_' + f'{"prompt" if is_prompt else "decode"}_' + f'bs{batch_size}_' + f'seq{seq_len}')): + output = self.model.sample( + logits=logits, + sampling_metadata=sampling_metadata, + ) + output.outputs = output.outputs[:real_batch_size] + htorch.core.mark_step() + + if self.is_driver_worker and self.profiler.enabled: + # Stop recording 'execute_model' event + self.profiler.end() + event_end = self.profiler.get_timestamp_us() + counters = self.profiler_counter_helper.get_counter_dict( + cache_config=self.cache_config, + duration=event_end - self.event_start, + seq_len=seq_len, + batch_size_padded=batch_size_padded, + real_batch_size=real_batch_size, + is_prompt=is_prompt) + self.profiler.record_counter(self.event_start, counters) + return [output] + + def shutdown_inc(self): + can_finalize_inc = False + from contextlib import suppress + with suppress(AttributeError): + can_finalize_inc = (self.model_config.quantization == 'inc') and \ + (self.model.model is not None) and \ + self.inc_initialized_successfully and \ + not getattr(self, "_is_inc_finalized", False) + if can_finalize_inc: + from neural_compressor.torch.quantization import ( + finalize_calibration) + finalize_calibration(self.model.model) + self._is_inc_finalized = True + + def __del__(self): + self.shutdown_inc() diff --git a/vllm/worker/hpu_worker.py b/vllm/worker/hpu_worker.py new file mode 100644 index 0000000000000..e33926aa0b7fd --- /dev/null +++ b/vllm/worker/hpu_worker.py @@ -0,0 +1,438 @@ +############################################################################### +# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company +############################################################################### + +import gc +import os +from typing import List, Optional, Set, Tuple, Type + +import habana_frameworks.torch as htorch # noqa:F401 +import torch +import torch.distributed +from vllm_hpu_extension.profiler import HabanaMemoryProfiler, format_bytes + +import vllm.envs as envs +from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, + ModelConfig, ObservabilityConfig, ParallelConfig, + PromptAdapterConfig, SchedulerConfig, + SpeculativeConfig) +from vllm.distributed import (ensure_model_parallel_initialized, + init_distributed_environment) +from vllm.logger import init_logger +from vllm.lora.request import LoRARequest +from vllm.model_executor import set_random_seed +from vllm.prompt_adapter.request import PromptAdapterRequest +from vllm.sequence import ExecuteModelRequest +from vllm.worker.cache_engine import CacheEngine +from vllm.worker.hpu_model_runner import HPUModelRunner +from vllm.worker.model_runner_base import ModelRunnerBase +from vllm.worker.worker_base import LocalOrDistributedWorkerBase, WorkerInput + +logger = init_logger(__name__) + + +class HPUWorker(LocalOrDistributedWorkerBase): + """A worker class that executes (a partition of) the model on a HPU. + + Each worker is associated with a single HPU. The worker is responsible for + maintaining the KV cache and executing the model on the HPU. In case of + distributed inference, each worker is assigned a partition of the model. + """ + + def __init__( + self, + model_config: ModelConfig, + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + device_config: DeviceConfig, + cache_config: CacheConfig, + load_config: LoadConfig, + local_rank: int, + rank: int, + distributed_init_method: str, + lora_config: Optional[LoRAConfig] = None, + speculative_config: Optional[SpeculativeConfig] = None, + prompt_adapter_config: Optional[PromptAdapterConfig] = None, + is_driver_worker: bool = False, + model_runner_cls: Optional[Type[ModelRunnerBase]] = None, + observability_config: Optional[ObservabilityConfig] = None, + ) -> None: + self.model_config = model_config + self.parallel_config = parallel_config + self.parallel_config.rank = rank + self.scheduler_config = scheduler_config + self.device_config = device_config + self.cache_config = cache_config + self.local_rank = local_rank + self.rank = rank + self.distributed_init_method = distributed_init_method + self.lora_config = lora_config + self.load_config = load_config + self.prompt_adapter_config = prompt_adapter_config + self.is_driver_worker = is_driver_worker + if self.is_driver_worker: + assert self.rank == 0, "The driver worker must have rank 0." + + if self.model_config.trust_remote_code: + # note: lazy import to avoid importing torch before initializing + from vllm.utils import init_cached_hf_modules + init_cached_hf_modules() + + self.model_runner: HPUModelRunner = HPUModelRunner( + model_config, + parallel_config, + scheduler_config, + device_config, + cache_config, + load_config=load_config, + lora_config=self.lora_config, + kv_cache_dtype=self.cache_config.cache_dtype, + is_driver_worker=is_driver_worker, + prompt_adapter_config=prompt_adapter_config, + observability_config=observability_config) + # Uninitialized cache engine. Will be initialized by + # initialize_cache. + self.cache_engine: List[HPUCacheEngine] + # Initialize gpu_cache as embedding models don't initialize kv_caches + self.hpu_cache: Optional[List[List[torch.tensor]]] = None + # Torch profiler. Enabled and configured through env vars: + # VLLM_TORCH_PROFILER_DIR=/path/to/save/trace + if envs.VLLM_TORCH_PROFILER_DIR: + torch_profiler_trace_dir = envs.VLLM_TORCH_PROFILER_DIR + logger.info("Profiling enabled. Traces will be saved to: %s", + torch_profiler_trace_dir) + self.profiler = torch.profiler.profile( + activities=[ + torch.profiler.ProfilerActivity.CPU, + torch.profiler.ProfilerActivity.HPU, + ], + with_stack=True, + on_trace_ready=torch.profiler.tensorboard_trace_handler( + torch_profiler_trace_dir, use_gzip=True)) + else: + self.profiler = None + + def start_profile(self): + if self.profiler is None: + raise RuntimeError("Profiler is not enabled.") + self.profiler.start() + + def stop_profile(self): + if self.profiler is None: + raise RuntimeError("Profiler is not enabled.") + self.profiler.stop() + + def _set_env_vars(self): + local_rank = self.local_rank + if self.parallel_config.world_size == 1: + local_rank = -1 + import os + os.environ["LOCAL_RANK"] = str(local_rank) + os.environ["ID"] = str(local_rank) + os.environ["WORLD_SIZE"] = str(self.parallel_config.world_size) + os.environ["RANK"] = str(self.rank) + + def init_device(self) -> None: + if self.device_config.device.type == "hpu": + self.device = torch.device("hpu") + torch.hpu.set_device(self.device) + else: + raise RuntimeError( + f"Not support device type: {self.device_config.device}") + # Initialize the distributed environment. + if self.model_config.quantization == 'inc': + self._set_env_vars() + init_worker_distributed_environment(self.parallel_config, self.rank, + self.distributed_init_method, + self.local_rank) + # Set random seed. + set_random_seed(self.model_config.seed) + + def load_model(self): + self.model_runner.load_model() + + @torch.inference_mode() + def determine_num_available_blocks(self) -> Tuple[int, int]: + """Profiles the peak memory usage of the model to determine how many + KV blocks may be allocated without OOMs. + + The engine will first conduct a profiling of the existing memory usage. + Then, it calculate the maximum possible number of GPU and CPU blocks + that can be allocated with the remaining free memory. + + .. tip:: + You may limit the usage of GPU memory + by adjusting the `gpu_memory_utilization` parameter. + """ + # Profile the memory usage of the model and get the maximum number of + # cache blocks that can be allocated with the remaining free memory. + + # Execute a forward pass with dummy inputs to profile the memory usage + # of the model. + with HabanaMemoryProfiler() as m: + self.model_runner.profile_run() + torch.hpu.synchronize() + msg = ("Model profiling run " + f"took {m.get_summary_string()}") + logger.info(msg) + # At this point we should've allocated the maximum workspace for all + # recipes we will use the extra memory for graphs/blocks + free_hpu_memory = torch.hpu.mem_get_info()[0] + + cache_block_size = self.get_cache_block_size_bytes() + graph_reserved_mem = (float( + os.environ.get('VLLM_GRAPH_RESERVED_MEM', '0.1')) + if not self.model_config.enforce_eager else 0) + graph_headroom = 1 - graph_reserved_mem + available_hpu_memory = free_hpu_memory * \ + self.cache_config.gpu_memory_utilization + hpu_memory_margin = free_hpu_memory * ( + 1 - self.cache_config.gpu_memory_utilization) + self.model_runner.mem_margin = hpu_memory_margin + cache_size_bytes = available_hpu_memory * graph_headroom + graph_headroom_bytes = available_hpu_memory * (1 - graph_headroom) + msg = ( + f"Free device memory: {format_bytes(free_hpu_memory)}, " + f"{format_bytes(available_hpu_memory)} usable " + f"(gpu_memory_utilization={self.cache_config.gpu_memory_utilization})," + f" {format_bytes(graph_headroom_bytes)} reserved for HPUGraphs " + f"(VLLM_GRAPH_RESERVED_MEM={graph_reserved_mem}), " + f"{format_bytes(cache_size_bytes)} reserved for KV cache") + logger.info(msg) + num_hpu_blocks = int(cache_size_bytes // cache_block_size) + num_cpu_blocks = int(self.cache_config.swap_space_bytes // + cache_block_size) + num_hpu_blocks = max(num_hpu_blocks, 0) + num_cpu_blocks = max(num_cpu_blocks, 0) + + if self.model_runner.lora_manager: + self.model_runner.remove_all_loras() + + gc.collect() + return num_hpu_blocks, num_cpu_blocks + + def initialize_cache(self, num_gpu_blocks: int, + num_cpu_blocks: int) -> None: + """Allocate GPU and CPU KV cache with the specified number of blocks. + + This also warms up the model, which may record CUDA graphs. + """ + raise_if_cache_size_invalid(num_gpu_blocks, + self.cache_config.block_size, + self.model_config.max_model_len) + + self.cache_config.num_gpu_blocks = num_gpu_blocks + self.cache_config.num_cpu_blocks = num_cpu_blocks + + with HabanaMemoryProfiler() as m: + self._init_cache_engine() + torch.hpu.synchronize() + msg = ("Initializing cache engine " + f"took {m.get_summary_string()}") + logger.info(msg) + self._warm_up_model() + + def _init_cache_engine(self): + assert self.cache_config.num_gpu_blocks is not None + self.cache_engine = [ + HPUCacheEngine(self.cache_config, self.model_config, + self.parallel_config, self.device_config) + for _ in range(self.parallel_config.pipeline_parallel_size) + ] + self.hpu_cache = [ + self.cache_engine[ve].gpu_cache + for ve in range(self.parallel_config.pipeline_parallel_size) + ] + + def _warm_up_model(self) -> None: + # NOTE(kzawora): We should use virtual engine index here + # for pipeline parallelism. Using 0 for now. + assert self.hpu_cache is not None + self.model_runner.warmup_model(self.hpu_cache[0]) + # Reset the seed to ensure that the random state is not affected by + # the model initialization and profiling. + set_random_seed(self.model_config.seed) + + def finish_measurements(self): + self.model_runner.finish_measurements() + + @property + def do_metadata_broadcast(self) -> bool: + return self.parallel_config.tensor_parallel_size > 1 + + @property + def kv_cache(self) -> Optional[List[List[torch.Tensor]]]: + return self.hpu_cache + + @torch.inference_mode() + def prepare_worker_input( + self, execute_model_req: ExecuteModelRequest) -> WorkerInput: + virtual_engine = execute_model_req.virtual_engine + num_seq_groups = len(execute_model_req.seq_group_metadata_list) + # `blocks_to_swap_in` and `blocks_to_swap_out` are cpu tensors. + # they contain parameters to launch cudamemcpyasync. + blocks_to_swap_in = torch.tensor(execute_model_req.blocks_to_swap_in, + device="cpu", + dtype=torch.int64).view(-1, 2) + blocks_to_swap_out = torch.tensor(execute_model_req.blocks_to_swap_out, + device="cpu", + dtype=torch.int64).view(-1, 2) + # `blocks_to_copy` is a gpu tensor. The src and tgt of + # blocks to copy are in the same device, and `blocks_to_copy` + # can be used directly within cuda kernels. + blocks_to_copy = torch.tensor(execute_model_req.blocks_to_copy, + device=self.device, + dtype=torch.int64).view(-1, 2) + + return WorkerInput( + num_seq_groups=num_seq_groups, + blocks_to_swap_in=blocks_to_swap_in, + blocks_to_swap_out=blocks_to_swap_out, + blocks_to_copy=blocks_to_copy, + virtual_engine=virtual_engine, + ) + + @torch.inference_mode() + def execute_worker(self, worker_input: WorkerInput) -> None: + virtual_engine = worker_input.virtual_engine + # Issue cache operations. + if (worker_input.blocks_to_swap_in is not None + and worker_input.blocks_to_swap_in.numel() > 0): + self.cache_engine[virtual_engine].swap_in( + worker_input.blocks_to_swap_in) + if (worker_input.blocks_to_swap_out is not None + and worker_input.blocks_to_swap_out.numel() > 0): + self.cache_engine[virtual_engine].swap_out( + worker_input.blocks_to_swap_out) + if (worker_input.blocks_to_copy is not None + and worker_input.blocks_to_copy.numel() > 0): + self.cache_engine[virtual_engine].copy(worker_input.blocks_to_copy) + + def add_lora(self, lora_request: LoRARequest) -> bool: + return self.model_runner.add_lora(lora_request) + + def remove_lora(self, lora_id: int) -> bool: + return self.model_runner.remove_lora(lora_id) + + def pin_lora(self, lora_id: int) -> bool: + return self.model_runner.pin_lora(lora_id) + + def list_loras(self) -> Set[int]: + return self.model_runner.list_loras() + + def add_prompt_adapter( + self, prompt_adapter_request: PromptAdapterRequest) -> bool: + raise NotImplementedError( + "Prompt Adapter is not implemented for HPU backend.") + + def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool: + raise NotImplementedError( + "Prompt Adapter is not implemented for HPU backend.") + + def pin_prompt_adapter(self, prompt_adapter_id: int) -> bool: + raise NotImplementedError( + "Prompt Adapter is not implemented for HPU backend.") + + def list_prompt_adapters(self) -> Set[int]: + raise NotImplementedError( + "Prompt Adapter is not implemented for HPU backend.") + + def shutdown_inc(self): + self.model_runner.shutdown_inc() + + @property + def max_model_len(self) -> int: + return self.model_config.max_model_len + + @property + def vocab_size(self) -> int: + return self.model_runner.vocab_size + + def get_cache_block_size_bytes(self) -> int: + """Get the size of the KV cache block size in bytes. + """ + return HPUCacheEngine.get_cache_block_size(self.cache_config, + self.model_config, + self.parallel_config) + + +def init_worker_distributed_environment( + parallel_config: ParallelConfig, + rank: int, + distributed_init_method: Optional[str] = None, + local_rank: int = -1, +) -> None: + """Initialize the distributed environment.""" + init_distributed_environment(parallel_config.world_size, + rank, + distributed_init_method, + local_rank, + backend='hccl') + + ensure_model_parallel_initialized(parallel_config.tensor_parallel_size, + parallel_config.pipeline_parallel_size) + + if torch.distributed.is_initialized(): + torch_world_size = torch.distributed.get_world_size() + if torch_world_size != parallel_config.world_size: + raise RuntimeError( + "torch.distributed is already initialized but the torch world " + "size does not match parallel_config.world_size " + f"({torch_world_size} vs. {parallel_config.world_size}).") + elif not distributed_init_method: + raise ValueError( + "distributed_init_method must be set if torch.distributed " + "is not already initialized") + else: + torch.distributed.init_process_group( + backend="hccl", + world_size=parallel_config.world_size, + rank=rank, + init_method=distributed_init_method, + ) + + # A small all_reduce for warmup & checking conformance. + dummy_tensor_hpu = torch.ones(1).to('hpu') + torch.distributed.all_reduce(dummy_tensor_hpu) + assert dummy_tensor_hpu.item() == parallel_config.world_size + ensure_model_parallel_initialized(parallel_config.tensor_parallel_size, + parallel_config.pipeline_parallel_size) + + +def raise_if_cache_size_invalid(num_gpu_blocks, block_size, + max_model_len) -> None: + if num_gpu_blocks <= 0: + raise ValueError("No available memory for the cache blocks. " + "Try increasing `gpu_memory_utilization` when " + "initializing the engine.") + max_seq_len = block_size * num_gpu_blocks + if max_model_len > max_seq_len: + raise ValueError( + f"The model's max seq len ({max_model_len}) " + "is larger than the maximum number of tokens that can be " + f"stored in KV cache ({max_seq_len}). Try increasing " + "`gpu_memory_utilization` or decreasing `max_model_len` when " + "initializing the engine.") + + +class HPUCacheEngine(CacheEngine): + + def _allocate_kv_cache( + self, + num_blocks: int, + device: str, + ) -> List[Tuple[torch.Tensor, torch.Tensor]]: + """Allocates KV cache on the specified device.""" + kv_cache_shape = self.attn_backend.get_kv_cache_shape( + num_blocks, self.block_size, self.num_kv_heads, self.head_size) + kv_cache: List[Tuple[torch.Tensor, torch.Tensor]] = [] + for _ in range(self.num_attention_layers): + key_cache = torch.zeros(kv_cache_shape, + dtype=self.dtype, + device=device) + value_cache = torch.zeros(kv_cache_shape, + dtype=self.dtype, + device=device) + kv_layer = (key_cache, value_cache) + kv_cache.append(kv_layer) + return kv_cache diff --git a/vllm/worker/model_runner_base.py b/vllm/worker/model_runner_base.py index 86883cf152449..84c1c8ed42a90 100644 --- a/vllm/worker/model_runner_base.py +++ b/vllm/worker/model_runner_base.py @@ -46,10 +46,15 @@ def _init_attn_metadata_from_tensor_dict( # Extract the fields used to create AttentionMetadata. valid_attn_kwargs = {} for field in dataclasses.fields(attn_backend.get_metadata_cls()): - val = tensor_dict.pop(field.name, None) - if val is not None: + # NOTE(kzawora): We use sentinel here, as None + # may be a valid value if type is optional. If + # we don't check against it, we will crash by not assigning + # Optional types without default value, even if they are + # broadcasted properly. + sentinel = object() + val = tensor_dict.pop(field.name, sentinel) + if val != sentinel: valid_attn_kwargs[field.name] = val - attn_metadata = attn_backend.make_metadata(**valid_attn_kwargs) tensor_dict["attn_metadata"] = attn_metadata return tensor_dict