Skip to content

Commit

Permalink
vLLM-Base: Full enabling of ALiBi
Browse files Browse the repository at this point in the history
Changes:
- Added back alibi biases to decode stage.
- Optimized ALiBI memory usage.
  - Added environment variable "VLLM_PROMPT_ALIBI_MAX_SEQ_LEN" to allow
    large models to run with restricted prompt lengths.
  - Prompt biases instantiated once in __init__ rather than each
    forward.
  - Prompt and decode biases are shared across encoder/decoder layers.
- Added environment variable "VLLM_ALIBI_USE_FLOAT32_BIASES" to resolve
  accuracy issue on long sequences.
- Updated jais, mpt, falcon, baichuan, and bloom to work with ALiBI.
  - Due to bloom's 176B parameter size I was unable to test this model.
    Its changes are the simplest though.
- Works in lazy and eager mode.
- ALiBI is restricted to "VLLM_PROMPT_USE_FUSEDSDPA=false", and
  "VLLM_CONTIGUOUS_PA=true".
- Add position offsets to improve quality on BS > 1 with sequences of
  varying length.
- BS > 1 may have accuracy issues if on FW < 1.19.0. This is due to
  limitation in softmax. Resolved on FW >= 1.19.0.
- NTT patch for GQA

Co-authored-by: Tanner Voas <[email protected]>
Co-authored-by: Haihao Xiang <[email protected]>
Signed-off-by: Tanner Voas <[email protected]>
  • Loading branch information
tannervoas742 and xhaihao committed Dec 10, 2024
1 parent def7ac2 commit 64822b0
Show file tree
Hide file tree
Showing 10 changed files with 328 additions and 100 deletions.
135 changes: 114 additions & 21 deletions vllm/attention/backends/hpu_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,8 @@ def __init__(
kv_cache_dtype: str,
blocksparse_params: Optional[Dict[str, Any]] = None,
max_seq_len: int = 4096,
tp_rank: Optional[int] = None,
prev_attn: Optional[Any] = None,
) -> None:
super(AttentionImpl, self).__init__()
self.kv_cache_dtype = kv_cache_dtype
Expand All @@ -142,11 +144,38 @@ def __init__(
else ModuleFusedSDPA(HPUFusedSDPA)
self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
self.sliding_window = sliding_window
self.alibi_slopes = alibi_slopes
self.alibi_slopes = None
self.prompt_position_bias = None
# Set upper bound on sequence length
self.max_seq_len = int(os.getenv(
'VLLM_PROMPT_ALIBI_MAX_SEQ_LEN',
max_seq_len,
))
# Set lower bound on sequence length
self.max_seq_len = max([
self.max_seq_len,
int(os.getenv('VLLM_PROMPT_SEQ_BUCKET_MAX', '0')),
])
self.tp_rank = tp_rank
self.prev_attn = None if prev_attn is None else prev_attn.impl
if alibi_slopes is not None:
alibi_slopes_tensor = torch.tensor(alibi_slopes,
dtype=torch.bfloat16)
self.alibi_slopes = alibi_slopes_tensor
if self.prev_attn is not None and self.prev_attn.tp_rank == self.tp_rank:

Check failure on line 162 in vllm/attention/backends/hpu_attn.py

View workflow job for this annotation

GitHub Actions / ruff (3.12)

Ruff (E501)

vllm/attention/backends/hpu_attn.py:162:81: E501 Line too long (85 > 80)
self.alibi_slopes = self.prev_attn.alibi_slopes
self.prompt_position_bias = self.prev_attn.prompt_position_bias
else:
slope_tensor_dtype = {
True: torch.float32,
False: torch.bfloat16,
}[os.getenv('VLLM_ALIBI_USE_FLOAT32_BIASES', '1').lower() in ['1', 'true']]

Check failure on line 169 in vllm/attention/backends/hpu_attn.py

View workflow job for this annotation

GitHub Actions / ruff (3.12)

Ruff (E501)

vllm/attention/backends/hpu_attn.py:169:81: E501 Line too long (91 > 80)
alibi_slopes_tensor = torch.tensor(alibi_slopes,
dtype=slope_tensor_dtype)
self.alibi_slopes = alibi_slopes_tensor
self.prompt_position_bias = _make_prompt_alibi_bias(
alibi_slopes=self.alibi_slopes,
seq_len=self.max_seq_len,
num_kv_heads=self.num_kv_heads,
dtype=self.alibi_slopes.dtype,
)
assert self.num_heads % self.num_kv_heads == 0
self.num_queries_per_kv = self.num_heads // self.num_kv_heads

Expand All @@ -157,6 +186,12 @@ def __init__(
assert alibi_slopes is None, \
'Prefill with FusedSDPA not supported with alibi slopes!'

self.use_contiguous_pa = os.environ.get('VLLM_CONTIGUOUS_PA',
'true').lower() == 'true'
if not self.use_contiguous_pa:
assert alibi_slopes is None, \
'Non-contiguous PA not supported with alibi slopes!'

suppored_head_sizes = HPUPagedAttention.get_supported_head_sizes()
if head_size not in suppored_head_sizes:
raise ValueError(
Expand Down Expand Up @@ -236,21 +271,38 @@ def forward(
assert attn_metadata.attn_bias is not None, \
'attn_bias must be set before calling model.forward'
attn_bias = attn_metadata.attn_bias
seq_lens_tensor = attn_metadata.seq_lens_tensor
position_bias = None
position_bias_offset = None
if self.alibi_slopes is not None:
position_bias = _make_alibi_bias(
self.alibi_slopes, self.num_kv_heads,
attn_bias.dtype, attn_bias.shape[-1])
attn_bias = attn_bias.tile(
(1, self.num_kv_heads, 1, 1))
attn_bias.add_(position_bias)
if self.max_seq_len >= max(attn_bias.size(-2), attn_bias.size(-1)):

Check failure on line 278 in vllm/attention/backends/hpu_attn.py

View workflow job for this annotation

GitHub Actions / ruff (3.12)

Ruff (E501)

vllm/attention/backends/hpu_attn.py:278:81: E501 Line too long (91 > 80)
position_bias = self.prompt_position_bias[:, :, -attn_bias.size(-2):, -attn_bias.size(-1):]

Check failure on line 279 in vllm/attention/backends/hpu_attn.py

View workflow job for this annotation

GitHub Actions / mypy (3.9)

Value of type "Optional[Any]" is not indexable [index]

Check failure on line 279 in vllm/attention/backends/hpu_attn.py

View workflow job for this annotation

GitHub Actions / ruff (3.12)

Ruff (E501)

vllm/attention/backends/hpu_attn.py:279:81: E501 Line too long (119 > 80)

Check failure on line 279 in vllm/attention/backends/hpu_attn.py

View workflow job for this annotation

GitHub Actions / mypy (3.10)

Value of type "Any | None" is not indexable [index]

Check failure on line 279 in vllm/attention/backends/hpu_attn.py

View workflow job for this annotation

GitHub Actions / mypy (3.11)

Value of type "Any | None" is not indexable [index]

Check failure on line 279 in vllm/attention/backends/hpu_attn.py

View workflow job for this annotation

GitHub Actions / mypy (3.12)

Value of type "Any | None" is not indexable [index]
else:
# Create new position bias if pre-computed ones are not sufficient.

Check failure on line 281 in vllm/attention/backends/hpu_attn.py

View workflow job for this annotation

GitHub Actions / ruff (3.12)

Ruff (E501)

vllm/attention/backends/hpu_attn.py:281:81: E501 Line too long (95 > 80)
# Repeatedly creating position biases is memory inneficient.

Check failure on line 282 in vllm/attention/backends/hpu_attn.py

View workflow job for this annotation

GitHub Actions / ruff (3.12)

Ruff (E501)

vllm/attention/backends/hpu_attn.py:282:81: E501 Line too long (88 > 80)
position_bias = _make_prompt_alibi_bias(
alibi_slopes=self.alibi_slopes,
seq_len=max(attn_bias.size(-2), attn_bias.size(-1)),

Check failure on line 285 in vllm/attention/backends/hpu_attn.py

View workflow job for this annotation

GitHub Actions / ruff (3.12)

Ruff (E501)

vllm/attention/backends/hpu_attn.py:285:81: E501 Line too long (84 > 80)
num_kv_heads=self.num_kv_heads,
dtype=self.alibi_slopes.dtype,
)
# Offsets are kept separate from position bias.
# Combining these together would scale position_bias size by batch size.

Check failure on line 290 in vllm/attention/backends/hpu_attn.py

View workflow job for this annotation

GitHub Actions / ruff (3.12)

Ruff (E501)

vllm/attention/backends/hpu_attn.py:290:81: E501 Line too long (96 > 80)
position_bias_offset = seq_lens_tensor.unsqueeze(1).tile(1, self.num_heads).to(dtype=self.alibi_slopes.dtype)

Check failure on line 291 in vllm/attention/backends/hpu_attn.py

View workflow job for this annotation

GitHub Actions / mypy (3.9)

Item "None" of "Optional[Any]" has no attribute "unsqueeze" [union-attr]

Check failure on line 291 in vllm/attention/backends/hpu_attn.py

View workflow job for this annotation

GitHub Actions / ruff (3.12)

Ruff (E501)

vllm/attention/backends/hpu_attn.py:291:81: E501 Line too long (133 > 80)

Check failure on line 291 in vllm/attention/backends/hpu_attn.py

View workflow job for this annotation

GitHub Actions / mypy (3.10)

Item "None" of "Any | None" has no attribute "unsqueeze" [union-attr]

Check failure on line 291 in vllm/attention/backends/hpu_attn.py

View workflow job for this annotation

GitHub Actions / mypy (3.11)

Item "None" of "Any | None" has no attribute "unsqueeze" [union-attr]

Check failure on line 291 in vllm/attention/backends/hpu_attn.py

View workflow job for this annotation

GitHub Actions / mypy (3.12)

Item "None" of "Any | None" has no attribute "unsqueeze" [union-attr]
position_bias_offset.mul_(self.alibi_slopes[None, :])
position_bias_offset = position_bias_offset - position_bias[:, :, -1, 0]

Check failure on line 293 in vllm/attention/backends/hpu_attn.py

View workflow job for this annotation

GitHub Actions / ruff (3.12)

Ruff (E501)

vllm/attention/backends/hpu_attn.py:293:81: E501 Line too long (96 > 80)
else:
attn_bias = None
position_bias = None
position_bias_offset = None

out = ops.prompt_attention(
query.view(query_shape),
key.view(kv_shape),
value.view(kv_shape),
attn_bias=attn_bias,
position_bias=position_bias,
position_bias_offset=position_bias_offset,
p=0.0,
scale=self.scale,
matmul_qk_op=self.matmul_qk,
Expand Down Expand Up @@ -278,6 +330,19 @@ def forward(
output = out.reshape(batch_size, seq_len, hidden_size)
else:
# Decoding run.
self.position_bias = None
alibi_blocks = attn_metadata.alibi_blocks
if self.alibi_slopes is not None and alibi_blocks is not None:
if self.prev_attn is not None and self.prev_attn.tp_rank == self.tp_rank:
self.position_bias = self.prev_attn.position_bias
else:
self.position_bias = _make_decode_alibi_bias(
alibi_blocks=alibi_blocks,
alibi_slopes=self.alibi_slopes,
num_kv_heads=self.num_kv_heads,
dtype=self.alibi_slopes.dtype,
)

output = HPUPagedAttention.forward_decode(
query=query,
key_cache=key_cache,
Expand All @@ -288,14 +353,18 @@ def forward(
block_scales=attn_metadata.block_scales,
block_groups=attn_metadata.block_groups,
scale=self.scale,
position_bias=self.position_bias,
matmul_qk_op=self.matmul_qk,
matmul_av_op=self.matmul_av,
batch2block_matmul_op=self.batch2block_matmul,
block2batch_matmul_op=self.block2batch_matmul,
keys_fetch_func=self.k_cache.fetch_from_cache,
values_fetch_func=self.v_cache.fetch_from_cache)
values_fetch_func=self.v_cache.fetch_from_cache,
)

# Reshape the output tensor.
return output.view(batch_size, seq_len, hidden_size)
output = output.view(batch_size, seq_len, hidden_size)
return output

def forward_encoder_decoder(
self,
Expand Down Expand Up @@ -409,11 +478,11 @@ def forward_encoder_decoder(
return output.view(batch_size, -1, hidden_size)


def _make_alibi_bias(
def _make_prompt_alibi_bias(
alibi_slopes: torch.Tensor,
seq_len: int,
num_kv_heads: int,
dtype: torch.dtype,
seq_len: int,
) -> torch.Tensor:
bias = torch.arange(seq_len, dtype=dtype)
# NOTE(zhuohan): HF uses
Expand All @@ -427,15 +496,39 @@ def _make_alibi_bias(

padded_len = (seq_len + 7) // 8 * 8
num_heads = alibi_slopes.shape[0]
bias = torch.empty(
1, # batch size
per_head_bias = torch.empty(
1,
num_heads,
seq_len,
padded_len,
device=alibi_slopes.device,
dtype=dtype,
)[:, :, :, :seq_len].copy_(bias)
bias.mul_(alibi_slopes[:, None, None])
if num_heads != num_kv_heads:
bias = bias.unflatten(1, (num_kv_heads, num_heads // num_kv_heads))
return bias
)[:, :, :, :seq_len]
# NOTE(Tanner):
# .copy_ was not performing broadcasting of bias to all 32 heads in Eager mode.
per_head_bias[:, :] = bias
per_head_bias.mul_(alibi_slopes[:, None, None])

return per_head_bias


def _make_decode_alibi_bias(
alibi_blocks: torch.Tensor,
alibi_slopes: torch.Tensor,
num_kv_heads: int,
dtype: torch.dtype,
) -> torch.Tensor:
num_heads = alibi_slopes.shape[0]
per_head_bias = torch.empty(
alibi_blocks.size(0), # num blocks
num_heads,
alibi_blocks.size(-1),
device=alibi_slopes.device,
dtype=dtype,
)
# NOTE(Tanner):
# .copy_ was not performing broadcasting of bias to all 32 heads in Eager mode.
per_head_bias[:, :] = alibi_blocks.unsqueeze(-2)
per_head_bias.mul_(alibi_slopes[None, :, None])

return per_head_bias
7 changes: 5 additions & 2 deletions vllm/attention/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,11 @@ def __init__(
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
blocksparse_params: Optional[Dict[str, Any]] = None,
logits_soft_cap: Optional[float] = None,
logits_soft_cap: Optional[int] = 4096,
per_layer_sliding_window: Optional[int] = None,
tp_rank: Optional[int] = None,
prefix: str = "",
prev_attn: Optional[Any] = None,
) -> None:
super().__init__()
if per_layer_sliding_window is not None:
Expand Down Expand Up @@ -96,7 +98,8 @@ def __init__(
impl_cls = attn_backend.get_impl_cls()
self.impl = impl_cls(num_heads, head_size, scale, num_kv_heads,

Check failure on line 99 in vllm/attention/layer.py

View workflow job for this annotation

GitHub Actions / mypy (3.9)

Too many arguments for "AttentionImpl" [call-arg]

Check failure on line 99 in vllm/attention/layer.py

View workflow job for this annotation

GitHub Actions / mypy (3.10)

Too many arguments for "AttentionImpl" [call-arg]

Check failure on line 99 in vllm/attention/layer.py

View workflow job for this annotation

GitHub Actions / mypy (3.11)

Too many arguments for "AttentionImpl" [call-arg]

Check failure on line 99 in vllm/attention/layer.py

View workflow job for this annotation

GitHub Actions / mypy (3.12)

Too many arguments for "AttentionImpl" [call-arg]
alibi_slopes, sliding_window, kv_cache_dtype,
blocksparse_params, logits_soft_cap)
blocksparse_params, logits_soft_cap,
tp_rank, prev_attn)
self.num_heads = num_heads
self.head_size = head_size
self.num_kv_heads = num_kv_heads
Expand Down
1 change: 1 addition & 0 deletions vllm/attention/ops/hpu_paged_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ class HPUPagedAttentionMetadata:
block_offsets: Optional[torch.Tensor]
block_scales: Optional[torch.Tensor]
block_groups: Optional[torch.Tensor]
alibi_blocks: Optional[torch.Tensor]


class HPUPagedAttention:
Expand Down
37 changes: 25 additions & 12 deletions vllm/model_executor/models/baichuan.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
# limitations under the License.
"""Inference-only BaiChuan model compatible with HuggingFace weights."""
import math
from typing import Iterable, List, Optional, Set, Tuple, Union
from typing import Any, Iterable, List, Optional, Set, Tuple, Union

import torch
from torch import nn
Expand Down Expand Up @@ -117,6 +117,7 @@ def __init__(
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
prev_attn: Optional[Any] = None,
):
super().__init__()
self.hidden_size = hidden_size
Expand All @@ -127,7 +128,7 @@ def __init__(
self.num_heads = (self.total_num_heads //
tensor_model_parallel_world_size)
self.head_dim = hidden_size // self.total_num_heads
self.postion_embedding = position_embedding
self.position_embedding = position_embedding
self.rope_theta = rope_theta
self.max_position_embeddings = max_position_embeddings

Expand All @@ -147,7 +148,7 @@ def __init__(
quant_config=quant_config,
)
# Create the alibi slopes and slice them.
if self.postion_embedding == "ALIBI":
if self.position_embedding == "ALIBI":
tp_rank = get_tensor_model_parallel_rank()
head_start = tp_rank * self.num_heads
head_end = (tp_rank + 1) * self.num_heads
Expand All @@ -160,7 +161,11 @@ def __init__(
scaling,
alibi_slopes=alibi_slopes,
quant_config=quant_config,
prefix=f"{prefix}.attn")
logits_soft_cap=self.max_position_embeddings,
tp_rank=tp_rank,
prefix=f"{prefix}.attn",
prev_attn=None if prev_attn is None else prev_attn.attn,
)
else:
self.rotary_emb = get_rope(
self.head_dim,
Expand All @@ -174,7 +179,8 @@ def __init__(
self.scaling,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.attn")
prefix=f"{prefix}.attn",
)

def forward(
self,
Expand All @@ -185,7 +191,7 @@ def forward(
) -> torch.Tensor:
qkv, _ = self.W_pack(hidden_states)
q, k, v = qkv.chunk(chunks=3, dim=-1)
if self.postion_embedding != "ALIBI":
if self.position_embedding != "ALIBI":
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
output, _ = self.o_proj(attn_output)
Expand All @@ -199,7 +205,9 @@ def __init__(self,
position_embedding: str,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""):
prefix: str = "",
prev_layer: Optional[Any] = None,
):
super().__init__()
self.hidden_size = config.hidden_size
rope_theta = getattr(config, "rope_theta", 10000)
Expand All @@ -214,6 +222,7 @@ def __init__(self,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.self_attn",
prev_attn=None if prev_layer is None else prev_layer.self_attn,
)
self.mlp = BaiChuanMLP(
hidden_size=self.hidden_size,
Expand Down Expand Up @@ -280,12 +289,15 @@ def __init__(
)
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: BaiChuanDecoderLayer(config,
position_embedding,
cache_config,
quant_config,
prefix=prefix),
lambda prefix, prev_layer: BaiChuanDecoderLayer(config,
position_embedding,
cache_config,
quant_config,
prefix=prefix,
prev_layer=prev_layer,
),
prefix=f"{prefix}.layers",
use_layer_sharing=True,
)
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.make_empty_intermediate_tensors = (
Expand Down Expand Up @@ -372,6 +384,7 @@ def __init__(
self.lora_config = lora_config

self.quant_config = quant_config
self.use_alibi = position_embedding == "ALIBI"
self.model = BaiChuanModel(vllm_config=vllm_config,
prefix=prefix,
position_embedding=position_embedding)
Expand Down
Loading

0 comments on commit 64822b0

Please sign in to comment.