diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py index 5be2d83346d00..aed04361e5fb4 100644 --- a/vllm/attention/backends/abstract.py +++ b/vllm/attention/backends/abstract.py @@ -247,5 +247,6 @@ def forward( k_scale: float = 1.0, v_scale: float = 1.0, attn_type: str = AttentionType.DECODER, + output: Optional[torch.Tensor] = None, ) -> torch.Tensor: raise NotImplementedError diff --git a/vllm/attention/backends/blocksparse_attn.py b/vllm/attention/backends/blocksparse_attn.py index 9e54c3b40c54e..99cb84346d84e 100644 --- a/vllm/attention/backends/blocksparse_attn.py +++ b/vllm/attention/backends/blocksparse_attn.py @@ -360,6 +360,7 @@ def forward( k_scale: float = 1.0, v_scale: float = 1.0, attn_type: str = AttentionType.DECODER, + output: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Forward pass with FlashAttention and PagedAttention. @@ -448,5 +449,6 @@ def forward( blocksparse_head_sliding_step=self.head_sliding_step, ) + assert output is not None # Reshape the output tensor. return output.view(num_tokens, hidden_size) diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index 32738d1043b1d..c69e12ad78c44 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -638,24 +638,27 @@ def forward( k_scale: float = 1.0, v_scale: float = 1.0, attn_type: str = AttentionType.DECODER, + output: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Forward pass with FlashAttention. Args: - query: shape = [num_tokens, num_heads * head_size] - key: shape = [num_tokens, num_kv_heads * head_size] - value: shape = [num_tokens, num_kv_heads * head_size] + query: shape = [num_tokens, num_heads, head_size] + key: shape = [num_tokens, num_kv_heads, head_size] + value: shape = [num_tokens, num_kv_heads, head_size] + output: shape = [num_tokens, num_heads, head_size] kv_cache = [2, num_blocks, block_size, num_kv_heads, head_size] NOTE: kv_cache will be an empty tensor with shape [0] for profiling run. attn_metadata: Metadata for attention. - Returns: - shape = [num_tokens, num_heads * head_size] + NOTE: It in-place updates the output tensor. """ # NOTE(woosuk): FlashAttention does not support FP8 KV cache. assert k_scale == 1.0 and v_scale == 1.0, ( "key/v_scale is not supported in FlashAttention.") + assert output is not None, "Output tensor must be provided." + if (attn_type == AttentionType.ENCODER and (not attn_metadata.is_all_encoder_attn_metadata_set)): raise AttributeError("Encoder attention requires setting " @@ -666,23 +669,12 @@ def forward( "requires setting cross-attention " "metadata attributes.") - num_heads: int = self.num_heads - head_size: int = self.head_size - num_kv_heads: int = self.num_kv_heads kv_cache_dtype: str = self.kv_cache_dtype softmax_scale: float = self.scale window_size = self.sliding_window alibi_slopes: Optional[torch.Tensor] = self.alibi_slopes logits_soft_cap: Optional[float] = self.logits_soft_cap - num_tokens, hidden_size = query.shape - - # Reshape the query, key, and value tensors. - query = query.view(-1, num_heads, head_size) - if (key is not None) and (value is not None): - key = key.view(-1, num_kv_heads, head_size) - value = value.view(-1, num_kv_heads, head_size) - if kv_cache.numel() > 0: key_cache = kv_cache[0] value_cache = kv_cache[1] @@ -721,13 +713,13 @@ def forward( num_decode_query_tokens) = \ get_num_prefill_decode_query_kv_tokens(attn_metadata, attn_type) decode_query = query[num_prefill_query_tokens:] + decode_output = output[num_prefill_query_tokens:] # QKV for prefill. query = query[:num_prefill_query_tokens] + prefill_output = output[:num_prefill_query_tokens] assert query.shape[0] == num_prefill_query_tokens assert decode_query.shape[0] == num_decode_query_tokens - prefill_output: Optional[torch.Tensor] = None - decode_output: Optional[torch.Tensor] = None if prefill_meta := attn_metadata.prefill_metadata: # Prompt run. if (kv_cache.numel() == 0 or prefill_meta.block_tables is None @@ -741,7 +733,7 @@ def forward( key = key[:num_prefill_kv_tokens] value = value[:num_prefill_kv_tokens] - prefill_output = flash_attn_varlen_func( + flash_attn_varlen_func( q=query, k=key, v=value, @@ -754,6 +746,7 @@ def forward( window_size=window_size, alibi_slopes=alibi_slopes, softcap=logits_soft_cap, + out=prefill_output, ) else: # prefix-enabled attention @@ -761,7 +754,7 @@ def forward( "Only decoder-only models support prefix caching") assert prefill_meta.seq_lens is not None max_seq_len = max(prefill_meta.seq_lens) - prefill_output = flash_attn_varlen_func( # noqa + flash_attn_varlen_func( # noqa q=query, k=key_cache, v=value_cache, @@ -775,6 +768,7 @@ def forward( alibi_slopes=alibi_slopes, block_table=prefill_meta.block_tables, softcap=logits_soft_cap, + out=prefill_output, ) if decode_meta := attn_metadata.decode_metadata: @@ -788,7 +782,7 @@ def forward( assert attn_type == AttentionType.DECODER, ( "Only decoder-only models support max_decode_query_len > 1" ) - decode_output = flash_attn_varlen_func( + flash_attn_varlen_func( q=decode_query, k=key_cache, v=value_cache, @@ -802,6 +796,7 @@ def forward( alibi_slopes=alibi_slopes, softcap=logits_soft_cap, block_table=decode_meta.block_tables, + out=decode_output, ) else: # Use flash_attn_with_kvcache for normal decoding. @@ -810,7 +805,7 @@ def forward( _, block_tables_arg, ) = get_seq_len_block_table_args(decode_meta, False, attn_type) - decode_output = flash_attn_with_kvcache( + flash_attn_with_kvcache( q=decode_query.unsqueeze(1), k_cache=key_cache, v_cache=value_cache, @@ -821,20 +816,8 @@ def forward( window_size=window_size, alibi_slopes=alibi_slopes, softcap=logits_soft_cap, - ).squeeze(1) - - if prefill_output is None: - assert decode_output is not None - return decode_output.view(num_decode_query_tokens, hidden_size) - if decode_output is None: - assert prefill_output is not None - return prefill_output.view(num_prefill_query_tokens, hidden_size) - - assert decode_meta is not None - decode_output = decode_output.squeeze(1) - output = torch.cat([prefill_output, decode_output], dim=0) - return output.view(num_tokens, hidden_size) - + out=decode_output.unsqueeze(1), + ) return output diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py index 1a2024705eb04..e367468d05d26 100644 --- a/vllm/attention/backends/flashinfer.py +++ b/vllm/attention/backends/flashinfer.py @@ -774,7 +774,11 @@ def forward( k_scale: float = 1.0, v_scale: float = 1.0, attn_type: str = AttentionType.DECODER, + output: Optional[torch.Tensor] = None, ) -> torch.Tensor: + + # TODO: directly write to output tensor + if attn_type != AttentionType.DECODER: raise NotImplementedError("Encoder self-attention and " "encoder/decoder cross-attention " diff --git a/vllm/attention/backends/hpu_attn.py b/vllm/attention/backends/hpu_attn.py index 5359941d41fde..2c62e565c04c7 100644 --- a/vllm/attention/backends/hpu_attn.py +++ b/vllm/attention/backends/hpu_attn.py @@ -145,6 +145,7 @@ def forward( k_scale: float = 1.0, v_scale: float = 1.0, attn_type: str = AttentionType.DECODER, + output: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Forward pass with xFormers and PagedAttention. diff --git a/vllm/attention/backends/ipex_attn.py b/vllm/attention/backends/ipex_attn.py index 3b0d51ea4a3d8..21949874bea47 100644 --- a/vllm/attention/backends/ipex_attn.py +++ b/vllm/attention/backends/ipex_attn.py @@ -173,6 +173,7 @@ def forward( k_scale: float = 1.0, v_scale: float = 1.0, attn_type: str = AttentionType.DECODER, + output: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Forward pass with IPEX varlen_attention and PagedAttention. diff --git a/vllm/attention/backends/pallas.py b/vllm/attention/backends/pallas.py index 5988be0e6b687..9809aed0e66f9 100644 --- a/vllm/attention/backends/pallas.py +++ b/vllm/attention/backends/pallas.py @@ -151,6 +151,7 @@ def forward( k_scale: float = 1.0, v_scale: float = 1.0, attn_type: str = AttentionType.DECODER, + output: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Forward pass with Pallas attention. diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index 6a494f4e73cb4..9139c3c1314d8 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -415,6 +415,7 @@ def forward( k_scale: float = 1.0, v_scale: float = 1.0, attn_type: str = AttentionType.DECODER, + output: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Forward pass with FlashAttention and PagedAttention. diff --git a/vllm/attention/backends/torch_sdpa.py b/vllm/attention/backends/torch_sdpa.py index dafa5bb56acda..86e952a903f36 100644 --- a/vllm/attention/backends/torch_sdpa.py +++ b/vllm/attention/backends/torch_sdpa.py @@ -431,6 +431,7 @@ def forward( k_scale: float = 1.0, v_scale: float = 1.0, attn_type: str = AttentionType.DECODER, + output: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Forward pass with torch SDPA and PagedAttention. diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index 292575a8736bc..e2e989efb020c 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -417,6 +417,7 @@ def forward( k_scale: float = 1.0, v_scale: float = 1.0, attn_type: str = AttentionType.DECODER, + output: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Forward pass with xFormers and PagedAttention. diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index 17157617248f7..e024eef286f05 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -4,7 +4,6 @@ import torch import torch.nn as nn -import vllm.envs as envs from vllm.attention import AttentionMetadata, AttentionType from vllm.attention.selector import backend_name_to_enum, get_attn_backend from vllm.config import CacheConfig, get_current_vllm_config @@ -12,7 +11,7 @@ from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod -from vllm.platforms import current_platform +from vllm.platforms import _Backend, current_platform from vllm.utils import direct_register_custom_op @@ -97,14 +96,23 @@ def __init__( self.impl = impl_cls(num_heads, head_size, scale, num_kv_heads, alibi_slopes, sliding_window, kv_cache_dtype, blocksparse_params, logits_soft_cap) + self.num_heads = num_heads + self.head_size = head_size + self.num_kv_heads = num_kv_heads self.backend = backend_name_to_enum(attn_backend.get_name()) # For cuda-alike (CUDA and ROCM) and cpu platforms, we control how # torch.compile works by registering the attention as one giant # opaque custom op. For other platforms, we directly call them # and let torch.compile handle them. - self.use_direct_call = envs.VLLM_USE_V1 or not ( - current_platform.is_cuda_alike() or current_platform.is_cpu()) + self.use_direct_call = not current_platform.is_cuda_alike( + ) and not current_platform.is_cpu() + + # For some attention backends, we allocate an output tensor before + # calling the custom op. When piecewise cudagraph is enabled, this + # makes sure the output tensor is allocated inside the cudagraph. + self.use_output = self.backend == _Backend.FLASH_ATTN or \ + self.backend == _Backend.FLASH_ATTN_VLLM_V1 compilation_config = get_current_vllm_config().compilation_config if prefix in compilation_config.static_forward_context: raise ValueError(f"Duplicate layer name: {prefix}") @@ -130,6 +138,22 @@ def forward( self._k_scale, self._v_scale, attn_type=attn_type) + elif self.use_output: + output = torch.empty_like(query) + hidden_size = query.size(-1) + # Reshape the query, key, and value tensors. + # NOTE(woosuk): We do this outside the custom op to minimize the + # CPU overheads from the non-CUDA-graph regions. + query = query.view(-1, self.num_heads, self.head_size) + output = output.view(-1, self.num_heads, self.head_size) + if key is not None: + key = key.view(-1, self.num_kv_heads, self.head_size) + if value is not None: + value = value.view(-1, self.num_kv_heads, self.head_size) + torch.ops.vllm.unified_attention_with_output( + query, key, value, output, kv_cache, attn_type, + self.layer_name) + return output.view(-1, hidden_size) else: return torch.ops.vllm.unified_attention(query, key, value, kv_cache, attn_type, @@ -183,3 +207,47 @@ def unified_attention_fake( fake_impl=unified_attention_fake, dispatch_key=current_platform.dispatch_key, ) + + +def unified_attention_with_output( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + output: torch.Tensor, + kv_cache: torch.Tensor, + attn_type: str, + layer_name: str, +) -> None: + forward_context: ForwardContext = get_forward_context() + attn_metadata = forward_context.dynamic_forward_context + self = forward_context.static_forward_context[layer_name] + self.impl.forward(query, + key, + value, + kv_cache, + attn_metadata, + self._k_scale, + self._v_scale, + attn_type=attn_type, + output=output) + + +def unified_attention_with_output_fake( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + output: torch.Tensor, + kv_cache: torch.Tensor, + attn_type: str, + layer_name: str, +) -> None: + return + + +direct_register_custom_op( + op_name="unified_attention_with_output", + op_func=unified_attention_with_output, + mutates_args=["kv_cache", "output"], + fake_impl=unified_attention_with_output_fake, + dispatch_key=current_platform.dispatch_key, +) diff --git a/vllm/config.py b/vllm/config.py index 510bd81d66217..5f50d65ec87e1 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -2238,7 +2238,7 @@ class CompilationConfig(BaseModel): custom_ops: List[str] = Field(default_factory=list) splitting_ops: List[str] = Field(default_factory=lambda: [ "vllm.unified_attention", - "vllm.unified_v1_flash_attention", + "vllm.unified_attention_with_output", ]) use_inductor: bool = True diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index 4aa4b296f0efc..d37989055c2e5 100644 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -6,8 +6,6 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionMetadata, AttentionType) -from vllm.forward_context import get_forward_context -from vllm.utils import direct_register_custom_op from vllm.vllm_flash_attn import flash_attn_varlen_func @@ -113,13 +111,14 @@ def forward( k_scale: float = 1.0, v_scale: float = 1.0, attn_type: AttentionType = AttentionType.DECODER, + output: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Forward pass with FlashAttention. Args: - query: shape = [num_tokens, num_heads * head_size] - key: shape = [num_tokens, num_kv_heads * head_size] - value: shape = [num_tokens, num_kv_heads * head_size] + query: shape = [num_tokens, num_heads, head_size] + key: shape = [num_tokens, num_kv_heads, head_size] + value: shape = [num_tokens, num_kv_heads, head_size] kv_cache = [2, num_blocks, block_size, num_kv_heads, head_size] attn_metadata: Metadata for attention. Returns: @@ -135,118 +134,42 @@ def forward( assert k_scale == 1.0 and v_scale == 1.0, ( "key/v_scale is not supported in FlashAttention.") - # Reshape the query, key, and value tensors. - # NOTE(woosuk): We do this outside the custom op to minimize the CPU - # overheads from the non-CUDA-graph regions. - query = query.view(-1, self.num_heads, self.head_size) - key = key.view(-1, self.num_kv_heads, self.head_size) - value = value.view(-1, self.num_kv_heads, self.head_size) - - output = torch.empty_like(query) - torch.ops.vllm.unified_v1_flash_attention( - output, - query, - key, - value, - self.num_heads, - self.head_size, - self.num_kv_heads, - kv_cache, + if attn_metadata is None: + # Profiling run. + return output + + num_actual_tokens = attn_metadata.num_actual_tokens + + # Reshape the input keys and values and store them in the cache. + key_cache = kv_cache[0] + value_cache = kv_cache[1] + torch.ops._C_cache_ops.reshape_and_cache_flash( + key[:num_actual_tokens], + value[:num_actual_tokens], + key_cache, + value_cache, + attn_metadata.slot_mapping, self.kv_cache_dtype, k_scale, v_scale, - self.scale, - self.sliding_window, - self.alibi_slopes, - self.logits_soft_cap, ) - return output.view(-1, self.num_heads * self.head_size) - - -def unified_v1_flash_attention( - output: torch.Tensor, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - num_heads: int, - head_size: int, - num_kv_heads: int, - kv_cache: torch.Tensor, - kv_cache_dtype: str, - k_scale: float, - v_scale: float, - softmax_scale: float, - window_size: Optional[List[int]] = None, - alibi_slopes: Optional[torch.Tensor] = None, - logits_soft_cap: Optional[float] = None, -) -> None: - context = get_forward_context() - current_metadata = context.dynamic_forward_context - if current_metadata is None: - # Profiling run. - return - - assert current_metadata is not None - assert isinstance(current_metadata, FlashAttentionMetadata) - attn_metadata: FlashAttentionMetadata = current_metadata - num_actual_tokens = attn_metadata.num_actual_tokens - - # Reshape the input keys and values and store them in the cache. - key_cache = kv_cache[0] - value_cache = kv_cache[1] - torch.ops._C_cache_ops.reshape_and_cache_flash( - key[:num_actual_tokens], - value[:num_actual_tokens], - key_cache, - value_cache, - attn_metadata.slot_mapping, - kv_cache_dtype, - k_scale, - v_scale, - ) - - # Compute attention and update output up to `num_actual_tokens`. - flash_attn_varlen_func( - q=query[:num_actual_tokens], - k=key_cache, - v=value_cache, - out=output[:num_actual_tokens], - cu_seqlens_q=attn_metadata.query_start_loc, - max_seqlen_q=attn_metadata.max_query_len, - cu_seqlens_k=attn_metadata.seq_start_loc, - max_seqlen_k=attn_metadata.max_seq_len, - softmax_scale=softmax_scale, - causal=True, - alibi_slopes=alibi_slopes, - window_size=window_size, - block_table=attn_metadata.block_table, - softcap=logits_soft_cap, - ) - - -def unified_v1_flash_attention_fake( - output: torch.Tensor, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - num_heads: int, - head_size: int, - num_kv_heads: int, - kv_cache: torch.Tensor, - kv_cache_dtype: str, - k_scale: float, - v_scale: float, - softmax_scale: float, - window_size: Optional[List[int]] = None, - alibi_slopes: Optional[torch.Tensor] = None, - logits_soft_cap: Optional[float] = None, -) -> None: - return - - -direct_register_custom_op( - op_name="unified_v1_flash_attention", - op_func=unified_v1_flash_attention, - mutates_args=["kv_cache", "output"], - fake_impl=unified_v1_flash_attention_fake, -) + + # Compute attention and update output up to `num_actual_tokens`. + flash_attn_varlen_func( + q=query[:num_actual_tokens], + k=key_cache, + v=value_cache, + out=output[:num_actual_tokens], + cu_seqlens_q=attn_metadata.query_start_loc, + max_seqlen_q=attn_metadata.max_query_len, + cu_seqlens_k=attn_metadata.seq_start_loc, + max_seqlen_k=attn_metadata.max_seq_len, + softmax_scale=self.scale, + causal=True, + alibi_slopes=self.alibi_slopes, + window_size=self.sliding_window, + block_table=attn_metadata.block_table, + softcap=self.logits_soft_cap, + ) + + return output