diff --git a/vllm/attention/backends/hpu_attn.py b/vllm/attention/backends/hpu_attn.py
index 96dafe8c2fcb1..cffc81016aa9d 100644
--- a/vllm/attention/backends/hpu_attn.py
+++ b/vllm/attention/backends/hpu_attn.py
@@ -125,6 +125,8 @@ def __init__(
         kv_cache_dtype: str,
         blocksparse_params: Optional[Dict[str, Any]] = None,
         max_seq_len: int = 4096,
+        tp_rank: Optional[int] = None,
+        prev_attn: Optional[Any] = None,
     ) -> None:
         super(AttentionImpl, self).__init__()
         self.kv_cache_dtype = kv_cache_dtype
@@ -142,11 +144,38 @@ def __init__(
             else ModuleFusedSDPA(HPUFusedSDPA)
         self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
         self.sliding_window = sliding_window
-        self.alibi_slopes = alibi_slopes
+        self.alibi_slopes = None
+        self.prompt_position_bias = None
+        # Upper bound on the sequence length used to pre-compute ALiBI biases
+        self.max_seq_len = int(os.getenv(
+            'VLLM_PROMPT_ALIBI_MAX_SEQ_LEN',
+            max_seq_len,
+        ))
+        # Make sure the pre-computed biases cover the largest prompt bucket
+        self.max_seq_len = max([
+            self.max_seq_len,
+            int(os.getenv('VLLM_PROMPT_SEQ_BUCKET_MAX', '0')),
+        ])
+        self.tp_rank = tp_rank
+        self.prev_attn = None if prev_attn is None else prev_attn.impl
         if alibi_slopes is not None:
-            alibi_slopes_tensor = torch.tensor(alibi_slopes,
-                                               dtype=torch.bfloat16)
-            self.alibi_slopes = alibi_slopes_tensor
+            if self.prev_attn is not None and self.prev_attn.tp_rank == self.tp_rank:
+                self.alibi_slopes = self.prev_attn.alibi_slopes
+                self.prompt_position_bias = self.prev_attn.prompt_position_bias
+            else:
+                slope_tensor_dtype = {
+                    True: torch.float32,
+                    False: torch.bfloat16,
+                }[os.getenv('VLLM_ALIBI_USE_FLOAT32_BIASES', '1').lower() in ['1', 'true']]
+                alibi_slopes_tensor = torch.tensor(alibi_slopes,
+                                                   dtype=slope_tensor_dtype)
+                self.alibi_slopes = alibi_slopes_tensor
+                self.prompt_position_bias = _make_prompt_alibi_bias(
+                    alibi_slopes=self.alibi_slopes,
+                    seq_len=self.max_seq_len,
+                    num_kv_heads=self.num_kv_heads,
+                    dtype=self.alibi_slopes.dtype,
+                )
         assert self.num_heads % self.num_kv_heads == 0
         self.num_queries_per_kv = self.num_heads // self.num_kv_heads
@@ -157,6 +186,12 @@ def __init__(
             assert alibi_slopes is None, \
                 'Prefill with FusedSDPA not supported with alibi slopes!'
 
+        self.use_contiguous_pa = os.environ.get('VLLM_CONTIGUOUS_PA',
+                                                'true').lower() == 'true'
+        if not self.use_contiguous_pa:
+            assert alibi_slopes is None, \
+                'Non-contiguous PA not supported with alibi slopes!'
+
         suppored_head_sizes = HPUPagedAttention.get_supported_head_sizes()
         if head_size not in suppored_head_sizes:
             raise ValueError(
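For orientation, here is a minimal standalone sketch of the precompute-and-slice idea used in __init__ above and in forward below: the prompt bias is built once for max_seq_len, and shorter prompts reuse its bottom-right corner. The helper names and the slope formula are illustrative simplifications, not the vLLM/HPU implementation.

import torch

def toy_alibi_slopes(num_heads: int) -> torch.Tensor:
    # Geometric slopes as in the ALiBI paper (assumes a power-of-two head count).
    start = 2.0 ** (-8.0 / num_heads)
    return torch.tensor([start ** (i + 1) for i in range(num_heads)])

def toy_prompt_bias(slopes: torch.Tensor, seq_len: int) -> torch.Tensor:
    # bias[q, k] is proportional to (k - q); the upper triangle is later
    # hidden by the causal mask, so it is not zeroed here.
    pos = torch.arange(seq_len)
    rel = pos[None, :] - pos[:, None]          # (seq_len, seq_len)
    return slopes[:, None, None] * rel[None]   # (num_heads, seq_len, seq_len)

max_seq_len, num_heads = 16, 4
slopes = toy_alibi_slopes(num_heads)
full = toy_prompt_bias(slopes, max_seq_len)

# A shorter prompt reuses the bottom-right corner of the precomputed bias.
q_len = kv_len = 5
sliced = full[:, -q_len:, -kv_len:]
fresh = toy_prompt_bias(slopes, q_len)
assert torch.allclose(sliced, fresh)

The equality above is what makes the slicing in the prompt path safe whenever the precomputed bias is at least as large as the current attention shape.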
@@ -236,21 +271,38 @@ def forward(
             assert attn_metadata.attn_bias is not None, \
                 'attn_bias must be set before calling model.forward'
             attn_bias = attn_metadata.attn_bias
+            seq_lens_tensor = attn_metadata.seq_lens_tensor
+            position_bias = None
+            position_bias_offset = None
             if self.alibi_slopes is not None:
-                position_bias = _make_alibi_bias(
-                    self.alibi_slopes, self.num_kv_heads,
-                    attn_bias.dtype, attn_bias.shape[-1])
-                attn_bias = attn_bias.tile(
-                    (1, self.num_kv_heads, 1, 1))
-                attn_bias.add_(position_bias)
+                if self.max_seq_len >= max(attn_bias.size(-2), attn_bias.size(-1)):
+                    position_bias = self.prompt_position_bias[
+                        :, :, -attn_bias.size(-2):, -attn_bias.size(-1):]
+                else:
+                    # Create a new position bias if the pre-computed one is not large enough.
+                    # Repeatedly creating position biases is memory inefficient.
+                    position_bias = _make_prompt_alibi_bias(
+                        alibi_slopes=self.alibi_slopes,
+                        seq_len=max(attn_bias.size(-2), attn_bias.size(-1)),
+                        num_kv_heads=self.num_kv_heads,
+                        dtype=self.alibi_slopes.dtype,
+                    )
+                # Offsets are kept separate from the position bias.
+                # Combining them would scale position_bias size by batch size.
+                position_bias_offset = seq_lens_tensor.unsqueeze(1).tile(
+                    1, self.num_heads).to(dtype=self.alibi_slopes.dtype)
+                position_bias_offset.mul_(self.alibi_slopes[None, :])
+                position_bias_offset = (position_bias_offset -
+                                        position_bias[:, :, -1, 0])
             else:
                 attn_bias = None
+                position_bias = None
+                position_bias_offset = None
 
             out = ops.prompt_attention(
                 query.view(query_shape),
                 key.view(kv_shape),
                 value.view(kv_shape),
                 attn_bias=attn_bias,
+                position_bias=position_bias,
+                position_bias_offset=position_bias_offset,
                 p=0.0,
                 scale=self.scale,
                 matmul_qk_op=self.matmul_qk,
@@ -278,6 +330,19 @@ def forward(
             output = out.reshape(batch_size, seq_len, hidden_size)
         else:
             # Decoding run.
+            self.position_bias = None
+            alibi_blocks = attn_metadata.alibi_blocks
+            if self.alibi_slopes is not None and alibi_blocks is not None:
+                if self.prev_attn is not None and self.prev_attn.tp_rank == self.tp_rank:
+                    self.position_bias = self.prev_attn.position_bias
+                else:
+                    self.position_bias = _make_decode_alibi_bias(
+                        alibi_blocks=alibi_blocks,
+                        alibi_slopes=self.alibi_slopes,
+                        num_kv_heads=self.num_kv_heads,
+                        dtype=self.alibi_slopes.dtype,
+                    )
+
             output = HPUPagedAttention.forward_decode(
                 query=query,
                 key_cache=key_cache,
@@ -288,14 +353,18 @@ def forward(
                 block_scales=attn_metadata.block_scales,
                 block_groups=attn_metadata.block_groups,
                 scale=self.scale,
+                position_bias=self.position_bias,
                 matmul_qk_op=self.matmul_qk,
                 matmul_av_op=self.matmul_av,
                 batch2block_matmul_op=self.batch2block_matmul,
                 block2batch_matmul_op=self.block2batch_matmul,
                 keys_fetch_func=self.k_cache.fetch_from_cache,
-                values_fetch_func=self.v_cache.fetch_from_cache)
+                values_fetch_func=self.v_cache.fetch_from_cache,
+            )
+
         # Reshape the output tensor.
-        return output.view(batch_size, seq_len, hidden_size)
+        output = output.view(batch_size, seq_len, hidden_size)
+        return output
 
     def forward_encoder_decoder(
         self,
@@ -409,11 +478,11 @@ def forward_encoder_decoder(
         return output.view(batch_size, -1, hidden_size)
 
 
-def _make_alibi_bias(
+def _make_prompt_alibi_bias(
     alibi_slopes: torch.Tensor,
+    seq_len: int,
     num_kv_heads: int,
     dtype: torch.dtype,
-    seq_len: int,
 ) -> torch.Tensor:
     bias = torch.arange(seq_len, dtype=dtype)
     # NOTE(zhuohan): HF uses
@@ -427,15 +496,39 @@
     padded_len = (seq_len + 7) // 8 * 8
 
     num_heads = alibi_slopes.shape[0]
-    bias = torch.empty(
-        1,  # batch size
+    per_head_bias = torch.empty(
+        1,
         num_heads,
         seq_len,
         padded_len,
         device=alibi_slopes.device,
         dtype=dtype,
-    )[:, :, :, :seq_len].copy_(bias)
-    bias.mul_(alibi_slopes[:, None, None])
-    if num_heads != num_kv_heads:
-        bias = bias.unflatten(1, (num_kv_heads, num_heads // num_kv_heads))
-    return bias
+    )[:, :, :, :seq_len]
+    # NOTE(Tanner):
+    # .copy_ was not broadcasting the bias across all heads in eager mode.
+    per_head_bias[:, :] = bias
+    per_head_bias.mul_(alibi_slopes[:, None, None])
+
+    return per_head_bias
+
+
+def _make_decode_alibi_bias(
+    alibi_blocks: torch.Tensor,
+    alibi_slopes: torch.Tensor,
+    num_kv_heads: int,
+    dtype: torch.dtype,
+) -> torch.Tensor:
+    num_heads = alibi_slopes.shape[0]
+    per_head_bias = torch.empty(
+        alibi_blocks.size(0),  # num blocks
+        num_heads,
+        alibi_blocks.size(-1),
+        device=alibi_slopes.device,
+        dtype=dtype,
+    )
+    # NOTE(Tanner):
+    # .copy_ was not broadcasting the bias across all heads in eager mode.
+    per_head_bias[:, :] = alibi_blocks.unsqueeze(-2)
+    per_head_bias.mul_(alibi_slopes[None, :, None])
+
+    return per_head_bias
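As a reading aid for the decode-path helper above, a tiny sketch of the shapes involved: per-block position offsets are broadcast across heads and scaled by each head's slope. The values below are hypothetical; the actual contract is whatever HPUPagedAttention.forward_decode expects.

import torch

# One KV-cache block of size 4 holding the last 4 tokens of a sequence;
# offsets are distances (all non-positive) from the token being decoded.
alibi_blocks = torch.tensor([[-3., -2., -1., 0.]])   # (num_blocks, block_size)
slopes = torch.tensor([0.5, 0.25])                   # (num_heads,)

# (num_blocks, num_heads, block_size): one bias row per head per block.
bias = alibi_blocks.unsqueeze(1) * slopes[None, :, None]
print(bias)
# tensor([[[-1.5000, -1.0000, -0.5000,  0.0000],
#          [-0.7500, -0.5000, -0.2500,  0.0000]]])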
diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py
index 05d997279893b..97860e8179770 100644
--- a/vllm/attention/layer.py
+++ b/vllm/attention/layer.py
@@ -38,9 +38,11 @@ def __init__(
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
         blocksparse_params: Optional[Dict[str, Any]] = None,
-        logits_soft_cap: Optional[float] = None,
+        logits_soft_cap: Optional[int] = 4096,
         per_layer_sliding_window: Optional[int] = None,
+        tp_rank: Optional[int] = None,
         prefix: str = "",
+        prev_attn: Optional[Any] = None,
     ) -> None:
         super().__init__()
         if per_layer_sliding_window is not None:
@@ -96,7 +98,8 @@ def __init__(
         impl_cls = attn_backend.get_impl_cls()
         self.impl = impl_cls(num_heads, head_size, scale, num_kv_heads,
                              alibi_slopes, sliding_window, kv_cache_dtype,
-                             blocksparse_params, logits_soft_cap)
+                             blocksparse_params, logits_soft_cap,
+                             tp_rank, prev_attn)
         self.num_heads = num_heads
         self.head_size = head_size
         self.num_kv_heads = num_kv_heads
diff --git a/vllm/attention/ops/hpu_paged_attn.py b/vllm/attention/ops/hpu_paged_attn.py
index e55a4de11fd6c..d1235e6ec7aa7 100644
--- a/vllm/attention/ops/hpu_paged_attn.py
+++ b/vllm/attention/ops/hpu_paged_attn.py
@@ -22,6 +22,7 @@ class HPUPagedAttentionMetadata:
     block_offsets: Optional[torch.Tensor]
     block_scales: Optional[torch.Tensor]
     block_groups: Optional[torch.Tensor]
+    alibi_blocks: Optional[torch.Tensor]
 
 
 class HPUPagedAttention:
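A toy sketch of the prev_attn chaining wired through Attention above: each layer hands its attention module to the next one, whose impl reuses the previously built ALiBI tensors when the tensor-parallel rank matches. The class names below are stand-ins for illustration, not vLLM APIs.

class ToyAlibiImpl:
    """Stand-in for the HPU attention impl: builds the bias once, or borrows it."""

    def __init__(self, slopes, tp_rank, prev_attn=None):
        self.tp_rank = tp_rank
        prev = None if prev_attn is None else prev_attn.impl
        if prev is not None and prev.tp_rank == tp_rank:
            # Reuse the tensors built by the previous layer on this rank.
            self.position_bias = prev.position_bias
        else:
            self.position_bias = self._build_bias(slopes)  # expensive, done once

    @staticmethod
    def _build_bias(slopes):
        return [s * -1.0 for s in slopes]  # placeholder for the real tensor


class ToyAttention:
    """Stand-in for the Attention wrapper that forwards prev_attn to its impl."""

    def __init__(self, slopes, tp_rank, prev_attn=None):
        self.impl = ToyAlibiImpl(slopes, tp_rank, prev_attn=prev_attn)


layer0 = ToyAttention([0.5, 0.25], tp_rank=0)
layer1 = ToyAttention([0.5, 0.25], tp_rank=0, prev_attn=layer0)
assert layer1.impl.position_bias is layer0.impl.position_bias  # shared, not rebuilt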
diff --git a/vllm/model_executor/models/baichuan.py b/vllm/model_executor/models/baichuan.py
index 5e68b7f165bf4..318e4fce30189 100644
--- a/vllm/model_executor/models/baichuan.py
+++ b/vllm/model_executor/models/baichuan.py
@@ -18,7 +18,7 @@
 # limitations under the License.
 """Inference-only BaiChuan model compatible with HuggingFace weights."""
 import math
-from typing import Iterable, List, Optional, Set, Tuple, Union
+from typing import Any, Iterable, List, Optional, Set, Tuple, Union
 
 import torch
 from torch import nn
@@ -117,6 +117,7 @@ def __init__(
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        prev_attn: Optional[Any] = None,
     ):
         super().__init__()
         self.hidden_size = hidden_size
@@ -127,7 +128,7 @@ def __init__(
         self.num_heads = (self.total_num_heads //
                           tensor_model_parallel_world_size)
         self.head_dim = hidden_size // self.total_num_heads
-        self.postion_embedding = position_embedding
+        self.position_embedding = position_embedding
         self.rope_theta = rope_theta
         self.max_position_embeddings = max_position_embeddings
@@ -147,7 +148,7 @@ def __init__(
             quant_config=quant_config,
         )
         # Create the alibi slopes and slice them.
-        if self.postion_embedding == "ALIBI":
+        if self.position_embedding == "ALIBI":
             tp_rank = get_tensor_model_parallel_rank()
             head_start = tp_rank * self.num_heads
             head_end = (tp_rank + 1) * self.num_heads
@@ -160,7 +161,11 @@ def __init__(
                 scaling,
                 alibi_slopes=alibi_slopes,
                 quant_config=quant_config,
-                prefix=f"{prefix}.attn")
+                logits_soft_cap=self.max_position_embeddings,
+                tp_rank=tp_rank,
+                prefix=f"{prefix}.attn",
+                prev_attn=None if prev_attn is None else prev_attn.attn,
+            )
         else:
             self.rotary_emb = get_rope(
                 self.head_dim,
@@ -174,7 +179,8 @@ def __init__(
                 self.scaling,
                 cache_config=cache_config,
                 quant_config=quant_config,
-                prefix=f"{prefix}.attn")
+                prefix=f"{prefix}.attn",
+            )
 
     def forward(
         self,
@@ -185,7 +191,7 @@ def forward(
     ) -> torch.Tensor:
         qkv, _ = self.W_pack(hidden_states)
         q, k, v = qkv.chunk(chunks=3, dim=-1)
-        if self.postion_embedding != "ALIBI":
+        if self.position_embedding != "ALIBI":
             q, k = self.rotary_emb(positions, q, k)
         attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
         output, _ = self.o_proj(attn_output)
@@ -199,7 +205,9 @@ def __init__(self,
                  position_embedding: str,
                  cache_config: Optional[CacheConfig] = None,
                  quant_config: Optional[QuantizationConfig] = None,
-                 prefix: str = ""):
+                 prefix: str = "",
+                 prev_layer: Optional[Any] = None,
+                 ):
         super().__init__()
         self.hidden_size = config.hidden_size
         rope_theta = getattr(config, "rope_theta", 10000)
@@ -214,6 +222,7 @@ def __init__(self,
             cache_config=cache_config,
             quant_config=quant_config,
             prefix=f"{prefix}.self_attn",
+            prev_attn=None if prev_layer is None else prev_layer.self_attn,
         )
         self.mlp = BaiChuanMLP(
             hidden_size=self.hidden_size,
@@ -280,12 +289,15 @@ def __init__(
         )
         self.start_layer, self.end_layer, self.layers = make_layers(
             config.num_hidden_layers,
-            lambda prefix: BaiChuanDecoderLayer(config,
-                                                position_embedding,
-                                                cache_config,
-                                                quant_config,
-                                                prefix=prefix),
+            lambda prefix, prev_layer: BaiChuanDecoderLayer(config,
+                                                            position_embedding,
+                                                            cache_config,
+                                                            quant_config,
+                                                            prefix=prefix,
+                                                            prev_layer=prev_layer,
+                                                            ),
             prefix=f"{prefix}.layers",
+            use_layer_sharing=True,
         )
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.make_empty_intermediate_tensors = (
@@ -372,6 +384,7 @@ def __init__(
 
         self.lora_config = lora_config
         self.quant_config = quant_config
+        self.use_alibi = position_embedding == "ALIBI"
         self.model = BaiChuanModel(vllm_config=vllm_config,
                                    prefix=prefix,
                                    position_embedding=position_embedding)
diff --git a/vllm/model_executor/models/bloom.py b/vllm/model_executor/models/bloom.py
index fee74f491acc1..14fea7cda7727 100644
--- a/vllm/model_executor/models/bloom.py
+++ b/vllm/model_executor/models/bloom.py
@@ -16,7 +16,8 @@
 # limitations under the License.
 """Inference-only BLOOM model compatible with HuggingFace weights."""
 import math
-from typing import Iterable, List, Optional, Set, Tuple, Union
+from typing import Any, Iterable, List, Optional, Set, Tuple, Union
+
 import torch
 from torch import nn
@@ -79,6 +80,7 @@ def __init__(
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        prev_attn: Optional[Any] = None,
     ):
         super().__init__()
         self.hidden_size = config.hidden_size
@@ -118,7 +120,10 @@ def __init__(
             alibi_slopes=alibi_slopes,
             cache_config=cache_config,
             quant_config=quant_config,
-            prefix=f"{prefix}.attn")
+            tp_rank=tp_rank,
+            prefix=f"{prefix}.attn",
+            prev_attn=None if prev_attn is None else prev_attn.attn,
+        )
 
     def forward(
         self,
@@ -171,16 +176,18 @@ def __init__(
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        prev_layer: Optional[Any] = None,
     ):
         super().__init__()
         hidden_size = config.hidden_size
 
         self.input_layernorm = nn.LayerNorm(hidden_size,
                                             eps=config.layer_norm_epsilon)
-        self.self_attention = BloomAttention(config,
-                                             cache_config,
+        self.self_attention = BloomAttention(config, cache_config,
                                              quant_config,
-                                             prefix=f"{prefix}.self_attention")
+                                             prefix=f"{prefix}.self_attention",
+                                             prev_attn=None if prev_layer is None else prev_layer.self_attention,
+                                             )
         self.post_attention_layernorm = nn.LayerNorm(
             hidden_size, eps=config.layer_norm_epsilon)
         self.mlp = BloomMLP(config, quant_config)
@@ -247,9 +254,12 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         # Transformer blocks
         self.start_layer, self.end_layer, self.h = make_layers(
             config.num_hidden_layers,
-            lambda prefix: BloomBlock(
-                config, cache_config, quant_config, prefix=prefix),
-            prefix=f"{prefix}.h")
+            lambda prefix, prev_layer: BloomBlock(
+                config, cache_config, quant_config, prefix=prefix,
+                prev_layer=prev_layer),
+            prefix=f"{prefix}.h",
+            use_layer_sharing=True,
+        )
 
         # Final Layer Norm
         self.ln_f = nn.LayerNorm(self.embed_dim,
                                  eps=config.layer_norm_epsilon)
@@ -299,6 +309,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         quant_config = vllm_config.quant_config
         self.config = config
         self.quant_config = quant_config
+        self.use_alibi = True
         self.transformer = BloomModel(vllm_config=vllm_config,
                                       prefix=maybe_prefix(
                                           prefix, "transformer"))
diff --git a/vllm/model_executor/models/falcon.py b/vllm/model_executor/models/falcon.py
index 8660cf79b9cdb..b52a6b01a3177 100644
--- a/vllm/model_executor/models/falcon.py
+++ b/vllm/model_executor/models/falcon.py
@@ -18,7 +18,8 @@
 """PyTorch Falcon model."""
 
 import math
-from typing import Iterable, List, Optional, Set, Tuple, Union
+from typing import Any, Iterable, List, Optional, Set, Tuple, Union
+
 import torch
 from torch import nn
@@ -85,6 +86,7 @@ def __init__(
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        prev_attn: Optional[Any] = None,
     ):
         super().__init__()
 
@@ -160,7 +162,9 @@ def __init__(
                 self.inv_norm_factor,
                 num_kv_heads=self.num_kv_heads,
                 quant_config=quant_config,
-                prefix=f"{prefix}.attn")
+                logits_soft_cap=max_position_embeddings,
+                prefix=f"{prefix}.attn",
+            )
         elif self.use_alibi:
             tp_rank = get_tensor_model_parallel_rank()
             head_start = tp_rank * self.num_heads
@@ -174,7 +178,10 @@ def __init__(
                 num_kv_heads=self.num_kv_heads,
                 alibi_slopes=alibi_slopes,
                 quant_config=quant_config,
-                prefix=f"{prefix}.attn")
+                tp_rank=tp_rank,
+                prefix=f"{prefix}.attn",
+                prev_attn=None if prev_attn is None else prev_attn.attn,
+            )
         else:
             self.attn = Attention(self.num_heads,
                                   self.head_dim,
@@ -182,7 +189,8 @@ def __init__(
                                   num_kv_heads=self.num_kv_heads,
                                   cache_config=cache_config,
                                   quant_config=quant_config,
-                                  prefix=f"{prefix}.attn")
+                                  prefix=f"{prefix}.attn",
+                                  )
 
     def forward(
         self,
@@ -246,15 +254,17 @@ def __init__(
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        prev_layer: Optional[Any] = None,
     ):
         super().__init__()
         hidden_size = config.hidden_size
         self.num_heads = config.num_attention_heads
-        self.self_attention = FalconAttention(
-            config,
-            cache_config,
-            quant_config,
-            prefix=f"{prefix}.self_attention")
+        self.self_attention = FalconAttention(config,
+                                              cache_config,
+                                              quant_config,
+                                              prefix=f"{prefix}.self_attention",
+                                              prev_attn=None if prev_layer is None else prev_layer.self_attention,
+                                              )
         self.mlp = FalconMLP(config, quant_config)
         self.config = config
@@ -354,7 +364,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.config = config
         self.embed_dim = config.hidden_size
         self.num_heads = config.num_attention_heads
-        self.use_alibi = config.alibi
 
         # Embedding + LN Embedding
         self.word_embeddings = VocabParallelEmbedding(
@@ -365,9 +374,15 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         # Transformer blocks
         self.start_layer, self.end_layer, self.h = make_layers(
             config.num_hidden_layers,
-            lambda prefix: FalconDecoderLayer(
-                config, cache_config, quant_config, prefix=prefix),
-            prefix=f"{prefix}.h")
+            lambda prefix, prev_layer: FalconDecoderLayer(config,
+                                                          cache_config,
+                                                          quant_config,
+                                                          prefix=prefix,
+                                                          prev_layer=prev_layer,
+                                                          ),
+            prefix=f"{prefix}.h",
+            use_layer_sharing=True,
+        )
 
         # Final Layer Norm
         self.ln_f = LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
@@ -419,6 +434,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         quant_config = vllm_config.quant_config
         self.config = config
         self.quant_config = quant_config
+        self.use_alibi = config.alibi
         self.transformer = FalconModel(vllm_config=vllm_config,
                                        prefix=maybe_prefix(
                                            prefix, "transformer"))
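The ALIBI branches above all follow the same pattern of computing slopes for every head and then slicing out the current tensor-parallel rank's share; a small sketch of that slicing, with a simplified slope formula standing in for the models' own _get_alibi_slopes helpers:

import torch

total_num_heads, tp_size = 8, 2
# Simplified ALiBI slopes (power-of-two head count assumed).
slopes = torch.tensor(
    [2.0 ** (-8.0 * (i + 1) / total_num_heads) for i in range(total_num_heads)])

num_heads = total_num_heads // tp_size
shards = []
for tp_rank in range(tp_size):
    head_start = tp_rank * num_heads
    head_end = (tp_rank + 1) * num_heads
    shards.append(slopes[head_start:head_end])  # this rank's alibi_slopes

# Every head's slope is owned by exactly one rank.
assert torch.equal(torch.cat(shards), slopes)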
diff --git a/vllm/model_executor/models/jais.py b/vllm/model_executor/models/jais.py
index 8c81dff6b5768..c42c97927b31f 100644
--- a/vllm/model_executor/models/jais.py
+++ b/vllm/model_executor/models/jais.py
@@ -19,7 +19,8 @@
 """Inference-only Jais model compatible with HuggingFace weights."""
 
 import math
-from typing import Iterable, List, Optional, Set, Tuple, Union
+from typing import Any, Iterable, List, Optional, Set, Tuple, Union
+
 import torch
 from torch import nn
@@ -77,6 +78,7 @@ def __init__(
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        prev_attn: Optional[Any] = None,
     ):
         super().__init__()
         self.hidden_size = config.hidden_size
@@ -116,7 +118,11 @@ def __init__(
             alibi_slopes=alibi_slopes,
             cache_config=cache_config,
             quant_config=quant_config,
-            prefix=f"{prefix}.attn")
+            logits_soft_cap=config.max_position_embeddings,
+            tp_rank=tp_rank,
+            prefix=f"{prefix}.attn",
+            prev_attn=None if prev_attn is None else prev_attn.attn,
+        )
 
     def forward(
         self,
@@ -181,6 +187,7 @@ def __init__(
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        prev_layer: Optional[Any] = None,
     ):
         super().__init__()
         hidden_size = config.hidden_size
@@ -191,7 +198,9 @@ def __init__(
         self.attn = JAISAttention(config,
                                   cache_config,
                                   quant_config,
-                                  prefix=f"{prefix}.attn")
+                                  prefix=f"{prefix}.attn",
+                                  prev_attn=None if prev_layer is None else prev_layer.attn,
+                                  )
         self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
         self.mlp = JAISMLP(inner_dim, config, quant_config)
@@ -245,11 +254,14 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
 
         self.start_layer, self.end_layer, self.h = make_layers(
             config.num_hidden_layers,
-            lambda prefix: JAISBlock(config=config,
-                                     cache_config=cache_config,
-                                     quant_config=quant_config,
-                                     prefix=prefix),
+            lambda prefix, prev_layer: JAISBlock(config=config,
+                                                 cache_config=cache_config,
+                                                 quant_config=quant_config,
+                                                 prefix=prefix,
+                                                 prev_layer=prev_layer,
+                                                 ),
             prefix=f"{prefix}.h",
+            use_layer_sharing=True,
         )
 
         self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
@@ -304,6 +316,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         quant_config = vllm_config.quant_config
         self.config = config
         self.quant_config = quant_config
+        self.use_alibi = config.position_embedding_type == "alibi"
         self.transformer = JAISModel(vllm_config=vllm_config,
                                      prefix=maybe_prefix(
                                          prefix, "transformer"))
diff --git a/vllm/model_executor/models/mpt.py b/vllm/model_executor/models/mpt.py
index 1235816413a44..a0416d7c39d97 100644
--- a/vllm/model_executor/models/mpt.py
+++ b/vllm/model_executor/models/mpt.py
@@ -1,6 +1,6 @@
 # Adapted from https://huggingface.co/mosaicml/mpt-7b/tree/main
 import math
-from typing import Iterable, List, Optional, Set, Tuple, Union
+from typing import Any, Iterable, List, Optional, Set, Tuple, Union
 
 import torch
 import torch.nn as nn
@@ -51,6 +51,7 @@ def __init__(
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        prev_attn: Optional[Any] = None,
     ):
         super().__init__()
         self.d_model = config.d_model
@@ -59,6 +60,7 @@ def __init__(
         self.clip_qkv = config.attn_config["clip_qkv"]
         self.qk_ln = config.attn_config["qk_ln"]
         self.alibi_bias_max = config.attn_config["alibi_bias_max"]
+        self.max_seq_len = config.max_seq_len
         if "kv_n_heads" in config.attn_config:
             self.total_num_kv_heads = config.attn_config['kv_n_heads']
         else:
@@ -117,7 +119,11 @@ def __init__(
             num_kv_heads=self.num_kv_heads,
             cache_config=cache_config,
             quant_config=quant_config,
-            prefix=f"{prefix}.attn")
+            logits_soft_cap=self.max_seq_len,
+            tp_rank=tp_rank,
+            prefix=f"{prefix}.attn",
+            prev_attn=None if prev_attn is None else prev_attn.attn,
+        )
 
     def forward(
         self,
@@ -179,6 +185,7 @@ def __init__(
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        prev_layer: Optional[Any] = None
     ):
         super().__init__()
         hidden_size = config.d_model
@@ -186,7 +193,9 @@ def __init__(
         self.attn = MPTAttention(config,
                                  cache_config,
                                  quant_config,
-                                 prefix=f"{prefix}.attn")
+                                 prefix=f"{prefix}.attn",
+                                 prev_attn=None if prev_layer is None else prev_layer.attn,
+                                 )
         self.norm_2 = nn.LayerNorm(hidden_size)
         self.ffn = MPTMLP(config, quant_config)
@@ -230,9 +239,11 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         )
         self.start_layer, self.end_layer, self.blocks = make_layers(
             config.n_layers,
-            lambda prefix: MPTBlock(
-                config, cache_config, quant_config, prefix=prefix),
-            prefix=f"{prefix}.blocks")
+            lambda prefix, prev_layer: MPTBlock(
+                config, cache_config, quant_config, prefix=prefix,
+                prev_layer=prev_layer),
+            prefix=f"{prefix}.blocks",
+            use_layer_sharing=True,
+        )
         self.norm_f = nn.LayerNorm(config.d_model)
         if config.no_bias:
             for module in self.modules():
@@ -288,7 +299,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.config = config
         assert config.tie_word_embeddings
         self.quant_config = quant_config
-
+        self.use_alibi = config.attn_config['alibi']
         self.transformer = MPTModel(vllm_config=vllm_config,
                                     prefix=maybe_prefix(prefix, "transformer"))
         self.lm_head = self.transformer.wte
diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py
index 7a1e1f9bf2be4..26aec62372f48 100644
--- a/vllm/model_executor/models/utils.py
+++ b/vllm/model_executor/models/utils.py
@@ -425,7 +425,7 @@ def merge_multimodal_embeddings(
 
 class LayerFn(Protocol):
 
-    def __call__(self, prefix: str) -> torch.nn.Module:
+    def __call__(self, prefix: str, prev_layer: Optional[Any] = None) -> torch.nn.Module:
         ...
 
@@ -508,6 +508,7 @@ def make_layers(
     num_hidden_layers: int,
     layer_fn: LayerFn,
     prefix: str,
+    use_layer_sharing: bool = False,
 ) -> Tuple[int, int, torch.nn.ModuleList]:
     """Make a list of layers with the given layer function, taking
     pipeline parallelism into account.
@@ -517,11 +518,25 @@ def make_layers(
     start_layer, end_layer = get_pp_indices(num_hidden_layers,
                                             get_pp_group().rank_in_group,
                                             get_pp_group().world_size)
-    modules = torch.nn.ModuleList(
-        [PPMissingLayer() for _ in range(start_layer)] + [
-            maybe_offload_to_cpu(layer_fn(prefix=f"{prefix}.{idx}"))
-            for idx in range(start_layer, end_layer)
-        ] + [PPMissingLayer() for _ in range(end_layer, num_hidden_layers)])
+    layers = []
+    for _ in range(start_layer):
+        layers.append(PPMissingLayer())
+
+    curr_layer = None
+    for idx in range(start_layer, end_layer):
+        if use_layer_sharing:
+            curr_layer = layer_fn(prefix=f"{prefix}.{idx}",
+                                  prev_layer=curr_layer)
+        else:
+            curr_layer = layer_fn(prefix=f"{prefix}.{idx}")
+        layers.append(maybe_offload_to_cpu(curr_layer))
+
+    for _ in range(end_layer, num_hidden_layers):
+        layers.append(PPMissingLayer())
+
+    modules = torch.nn.ModuleList(layers)
+    return start_layer, end_layer, modules
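A minimal sketch of the use_layer_sharing contract added to make_layers above: when enabled, the layer factory receives the previously built layer so it can borrow precomputed state. make_toy_layers and ToyLayer below are simplified stand-ins for illustration, not the real vLLM function or models.

from typing import Any, Optional

import torch


class ToyLayer(torch.nn.Module):

    def __init__(self, prefix: str, prev_layer: Optional[Any] = None):
        super().__init__()
        self.prefix = prefix
        # Borrow the expensive buffer from the previous layer when possible.
        if prev_layer is not None:
            self.bias = prev_layer.bias
        else:
            self.bias = torch.zeros(16)


def make_toy_layers(num_layers, layer_fn, prefix, use_layer_sharing=False):
    layers = []
    curr_layer = None
    for idx in range(num_layers):
        if use_layer_sharing:
            curr_layer = layer_fn(prefix=f"{prefix}.{idx}", prev_layer=curr_layer)
        else:
            curr_layer = layer_fn(prefix=f"{prefix}.{idx}")
        layers.append(curr_layer)
    return torch.nn.ModuleList(layers)


layers = make_toy_layers(
    4,
    lambda prefix, prev_layer=None: ToyLayer(prefix, prev_layer=prev_layer),
    prefix="model.layers",
    use_layer_sharing=True,
)
assert all(layer.bias is layers[0].bias for layer in layers)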
diff --git a/vllm/worker/hpu_model_runner.py b/vllm/worker/hpu_model_runner.py
index 48c4af5f915fa..959ee632c6050 100755
--- a/vllm/worker/hpu_model_runner.py
+++ b/vllm/worker/hpu_model_runner.py
@@ -733,6 +733,7 @@ def load_model(self) -> None:
                 self.model = self.model.to("hpu")
                 htcore.mark_step()
 
+        self.use_alibi = hasattr(self.model, "use_alibi") and self.model.use_alibi
         hidden_layer_markstep_interval = int(
             os.getenv('VLLM_CONFIG_HIDDEN_LAYERS', '1'))
         model_config = getattr(self.model, "config", None)
@@ -975,11 +976,11 @@ def _prepare_prompt(
             block_list=prefix_block_list_tensor,
             block_mapping=None,
             block_usage=None,
-            block_indices=None,
-            block_offsets=None,
-            block_scales=None,
+            block_indices=None,  # set later by precompute_indices_and_offsets
+            block_offsets=None,  # set later by precompute_indices_and_offsets
+            block_scales=None,  # set later by _set_block_scales
             block_groups=None,
-            attn_bias=None,
+            attn_bias=None,  # set later by _set_attn_bias
             seq_lens=seq_lens,
             seq_lens_tensor=seq_lens_tensor,
             context_lens_tensor=context_lens_tensor,
@@ -987,8 +988,9 @@ def _prepare_prompt(
             num_prefill_tokens=num_prefill_tokens,
             num_decode_tokens=0,
             slot_mapping=slot_mapping,
-            multi_modal_placeholder_index_maps=
-            None  # FIXME(kzawora): mutli-modality will not work here
+            alibi_blocks=None,
+            # FIXME(kzawora): multi-modality will not work here
+            multi_modal_placeholder_index_maps=None,
         )
         multi_modal_kwargs = MultiModalKwargs.batch(multi_modal_kwargs_list)
         for t in multi_modal_kwargs:
@@ -1127,7 +1129,15 @@ def _prepare_decode(
         block_groups = padding_fn(block_groups, -1)
         block_usage = padding_fn(block_usage, 1)
 
-        block_list = torch.tensor(block_list, dtype=torch.int, device='cpu')
+        alibi_blocks = None
+        if self.use_alibi:
+            alibi_blocks = self._compute_alibi_block(block_tables, seq_lens, len(block_groups))
+            alibi_blocks = alibi_blocks.to(  # type: ignore
+                self.device, non_blocking=True)
+
+        block_list = torch.tensor(block_list,
+                                  dtype=torch.int,
+                                  device='cpu')
         block_groups = torch.tensor(block_groups,
                                     dtype=torch.int,
                                     device='cpu')
@@ -1154,20 +1164,22 @@ def _prepare_decode(
         attn_metadata = self.attn_backend.make_metadata(
             is_prompt=False,
             block_list=block_list,
-            block_mapping=None,
+            block_mapping=None,  # set later by _set_block_mapping
             block_usage=block_usage,
-            block_indices=None,
-            block_offsets=None,
-            block_scales=None,
+            block_indices=None,  # set later by precompute_indices_and_offsets
+            block_offsets=None,  # set later by precompute_indices_and_offsets
+            block_scales=None,  # set later by _set_block_scales
             block_groups=block_groups,
-            attn_bias=None,
+            attn_bias=None,  # set later by _set_block_mapping
             seq_lens_tensor=None,
             context_lens_tensor=None,
             num_prefills=0,
             num_prefill_tokens=0,
             num_decode_tokens=num_decode_tokens,
             slot_mapping=slot_mapping,
-            multi_modal_placeholder_index_maps=None)
+            alibi_blocks=alibi_blocks,
+            multi_modal_placeholder_index_maps=None,
+        )
         return PrepareDecodeMetadata(input_tokens=input_tokens,
                                      input_positions=input_positions,
                                      attn_metadata=attn_metadata,
@@ -1177,6 +1189,30 @@ def _prepare_decode(
                                      slot_mapping=slot_mapping,
                                      lora_ids=lora_ids)
 
+    def _compute_alibi_block(self, block_tables, seq_lens, num_blocks):
+        # Create intermediate and output structures
+        max_block_table_len = max(
+            len(block_table) for block_table in block_tables)
+        alibi_offsets = torch.arange(-max_block_table_len * self.block_size + 1,
+                                     1,
+                                     dtype=torch.long,
+                                     device='cpu')
+        alibi_blocks = torch.zeros((num_blocks, self.block_size),
+                                   dtype=torch.long,
+                                   device='cpu')
+
+        # Assign biases per token
+        for batch_idx in range(len(block_tables)):
+            seq_len = seq_lens[batch_idx]
+            for seq_idx in range(len(block_tables[batch_idx])):
+                block_idx = block_tables[batch_idx][seq_idx]
+
+                # Calculate the number of valid positions in the current block
+                valid_length = seq_len - seq_idx * self.block_size
+                if valid_length > 0:
+                    current_block_length = min(valid_length, self.block_size)
+                    offset_end = current_block_length - valid_length
+                    if offset_end == 0:
+                        alibi_blocks[block_idx][:current_block_length] = \
+                            alibi_offsets[-valid_length:]
+                    else:
+                        alibi_blocks[block_idx][:current_block_length] = \
+                            alibi_offsets[-valid_length:offset_end]
+
+        return alibi_blocks
+
     def prepare_input_tensors(
         self,
         seq_group_metadata_list: List[SequenceGroupMetadata],
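To make the per-block offsets computed by _compute_alibi_block concrete, here is a small worked example under assumed inputs (block_size 4, two sequences). It re-derives the expected values independently instead of calling the runner, so treat it as an illustration of the layout rather than a test of the code above.

import torch

block_size = 4
# Two sequences: lengths 4 and 6, stored in physical blocks [1] and [2, 3].
block_tables = [[1], [2, 3]]
seq_lens = [4, 6]
num_blocks = 5  # padded total, as len(block_groups) would provide

alibi_blocks = torch.zeros((num_blocks, block_size), dtype=torch.long)
for block_table, seq_len in zip(block_tables, seq_lens):
    for blk_pos, block_idx in enumerate(block_table):
        valid = seq_len - blk_pos * block_size  # tokens of this seq in this block
        if valid <= 0:
            continue
        n = min(valid, block_size)
        # Slot j holds the token whose distance from the current (last) token
        # is j - (valid - 1); the final token of the sequence gets offset 0.
        alibi_blocks[block_idx, :n] = torch.arange(n) - (valid - 1)

print(alibi_blocks.tolist())
# [[0, 0, 0, 0],
#  [-3, -2, -1, 0],
#  [-5, -4, -3, -2],
#  [-1, 0, 0, 0],
#  [0, 0, 0, 0]]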
@@ -1381,6 +1417,7 @@ def trim_attn_metadata(self, metadata: AttentionMetadata) -> object:
             'block_offsets',
             'block_scales',
             'block_groups',
+            'alibi_blocks',
         ])
         return attention_metadata
 
@@ -1389,7 +1426,8 @@ def create_dummy_seq_group_metadata(self,
                                         seq_len,
                                         is_prompt,
                                         lora_request=None,
-                                        temperature=0):
+                                        temperature=0,
+                                        last_block_assigned=0):
         sampling_params = SamplingParams(temperature=temperature)
         num_blocks = math.ceil(seq_len / self.block_size)
         seq_len = max(seq_len, 1)
@@ -1400,7 +1438,13 @@ def create_dummy_seq_group_metadata(self,
         else:
             input_len = seq_len - 1
             output_len = 1
-        block_tables = {group_id: [_PAD_BLOCK_ID] * num_blocks}
+        # NOTE(Tanner):
+        # ALiBI biases fail if the block_tables of dummy sequences are all zeros.
+        # By default "_PAD_BLOCK_ID" is "0", which is not a realistic block table value.
+        block_tables = {group_id: []}
+        for block_idx in range(num_blocks):
+            last_block_assigned += 1
+            block_tables[group_id] += [last_block_assigned]
         prompt_token_ids = [0] * input_len
         output_token_ids = [1] * output_len
         prompt_token_ids_array = array('l', prompt_token_ids)  # noqa: F821
@@ -1474,18 +1518,26 @@ def warmup_scenario(self,
                     temperature=temperature) for i in range(batch_size)
             ]
         else:
-            # FIXME: seq_len is actually number of blocks
-            blocks = [seq_len // batch_size for _ in range(batch_size)]
-            blocks[0] += seq_len % batch_size
-            seqs = [
-                self.create_dummy_seq_group_metadata(
+            # NOTE(Tanner):
+            # seq_len is the number of blocks here.
+            # Assign as many blocks to each sequence as possible.
+            blocks_per_seq = (seq_len - 1) // batch_size
+            extra_blocks = (seq_len - 1) % batch_size
+            blocks = [blocks_per_seq + (1 if i < extra_blocks else 0) for i in range(batch_size)]
+            seqs = []
+            last_block_assigned = 0
+            for i, b in enumerate(blocks):
+                seqs += [self.create_dummy_seq_group_metadata(
                     i,
-                    b * self.block_size - 1,
+                    b * self.block_size,
                     is_prompt,
                     lora_request=dummy_lora_requests_per_seq[i]
                     if dummy_lora_requests_per_seq else None,
-                    temperature=temperature) for i, b in enumerate(blocks)
-            ]
+                    temperature=temperature,
+                    last_block_assigned=last_block_assigned,
+                )]
+                if len(seqs[-1].block_tables[i]) > 0:
+                    last_block_assigned = seqs[-1].block_tables[i][-1]
         torch.hpu.synchronize()
         profiler = None
         if is_pt_profiler_run and self.is_driver_worker:
@@ -1952,7 +2004,7 @@ def create_lora_mask(self, input_tokens: torch.Tensor, lora_ids: List[int],
         This is a helper function to create the mask for lora computations.
         Lora Mask is needed to ensure we match the correct lora weights for
         the request.
-        For Prompt phase we have 
+        For Prompt phase we have
         lora_mask with shape (batch_size * seq_len, max_loras * max_rank)
         lora_logits_mask with shape (batch_size, max_loras * max_rank)
         For Decode phase we have both