diff --git a/flash_attn/flash_attn_triton_amd/fwd_prefill.py b/flash_attn/flash_attn_triton_amd/fwd_prefill.py
index 4cbfb1729..955e0fc2c 100644
--- a/flash_attn/flash_attn_triton_amd/fwd_prefill.py
+++ b/flash_attn/flash_attn_triton_amd/fwd_prefill.py
@@ -1,7 +1,7 @@
 import torch
 import triton
 import triton.language as tl
-from .utils import DEBUG, DROPOUT_USE_PYTORCH, DROPOUT_DUMP, AUTOTUNE, MetaData, get_shape_from_layout, get_strides_from_layout, is_cdna, is_rdna, write_dropout_mask, create_dropout_mask
+from .utils import DEBUG, DROPOUT_USE_PYTORCH, DROPOUT_DUMP, AUTOTUNE, get_shape_from_layout, get_strides_from_layout, is_cdna, is_rdna, write_dropout_mask, create_dropout_mask
 
 # NOTE: triton fails to import tl.constexprs so create them here for the file
 tl_DROPOUT_USE_PYTORCH: tl.constexpr = DROPOUT_USE_PYTORCH
@@ -559,22 +559,6 @@ def attention_prefill_forward_triton_impl(
                                            return_softmax,
                                            use_exp2):
 
-
-    if q.dtype in FP8_TYPES:
-        is_fp8 = True
-        q_scale = fp8_metadata.q_scale
-        k_scale = fp8_metadata.k_scale
-        v_scale = fp8_metadata.v_scale
-        p_scale = fp8_metadata.p_scale
-        p_inv_scale = fp8_metadata.p_inv_scale
-        q_scale_stride_z = q_scale.stride(0)
-        kv_scale_stride_z = k_scale.stride(0)
-        p_scale_stride_z = p_scale.stride(0)
-        p_inv_scale_stride_z = p_inv_scale.stride(0)
-    else:
-        q_scale = k_scale = v_scale = p_scale = p_inv_scale = 1
-        q_scale_stride_z = kv_scale_stride_z = p_scale_stride_z = p_inv_scale_stride_z = 0
-
     if DEBUG:
         print()
         print("attention_prefill_forward_triton_impl")
@@ -582,11 +566,6 @@
         print("k:", k, k.shape)
         print("v:", v, v.shape)
         print("o:", o, o.shape)
-        print("q_scale:", q_scale)
-        print("k_scale:", k_scale)
-        print("v_scale:", v_scale)
-        print("p_scale:", p_scale)
-        print("p_inv_scale:", p_inv_scale)
         print("sm_scale:", sm_scale)
         print("alibi_slopes:", alibi_slopes)
         print("causal:", causal)
@@ -602,6 +581,60 @@
         print("return_scores:", return_softmax)
         print("use_exp2:", use_exp2)
 
+    # Define FP8 types we support
+    FP8_TYPES = {torch.float8_e4m3fnuz, torch.float8_e4m3fn, torch.float8_e5m2, torch.float8_e5m2fnuz}
+
+    # Simple check if tensors are FP8
+    is_fp8 = q.dtype in FP8_TYPES
+
+    if is_fp8:
+        # Convert to float32 for scale computation
+        q_float32 = q.detach().to(torch.float32)
+        k_float32 = k.detach().to(torch.float32)
+        v_float32 = v.detach().to(torch.float32)
+
+        # Get shapes for scaling
+        batch = q.size(0) if layout != "thd" else len(cu_seqlens_q) - 1
+        nheads_q = q.size(1) if layout == "bhsd" else q.size(2)
+        nheads_k = k.size(1) if layout == "bhsd" else k.size(2)
+
+        # Compute global max values
+        eps = 1e-9
+        q_max = max(q_float32.abs().max().item(), eps)
+        k_max = max(k_float32.abs().max().item(), eps)
+        v_max = max(v_float32.abs().max().item(), eps)
+
+        # Create scale tensors with the global values
+        q_scale = torch.full((batch, nheads_q), q_max, dtype=torch.float32, device=q.device)
+        k_scale = torch.full((batch, nheads_k), k_max, dtype=torch.float32, device=k.device)
+        v_scale = torch.full((batch, nheads_k), v_max, dtype=torch.float32, device=v.device)
+
+        # Simple p_scale for softmax computation
+        p_scale = torch.full((batch, nheads_q), 1.0, dtype=torch.float32, device=q.device)
+        p_inv_scale = torch.full((batch, nheads_q), 1.0, dtype=torch.float32, device=q.device)
+
+        # Get strides for the kernel
+        q_scale_stride_z = q_scale.stride(0)
+        kv_scale_stride_z = k_scale.stride(0)
+        p_scale_stride_z = p_scale.stride(0)
+        p_inv_scale_stride_z = p_inv_scale.stride(0)
+    else:
+        # For non-FP8 types, use dummy values (no scaling needed)
+        q_scale = k_scale = v_scale = p_scale = p_inv_scale = 1
+        q_scale_stride_z = kv_scale_stride_z = p_scale_stride_z = p_inv_scale_stride_z = 0
+
+    if DEBUG:
+        print("is_fp8:", is_fp8)
+        print("q_scale:", q_scale)
+        print("k_scale:", k_scale)
+        print("v_scale:", v_scale)
+        print("p_scale:", p_scale)
+        print("p_inv_scale:", p_inv_scale)
+        print("q_scale_stride_z:", q_scale_stride_z)
+        print("kv_scale_stride_z:", kv_scale_stride_z)
+        print("p_scale_stride_z:", p_scale_stride_z)
+        print("p_inv_scale_stride_z:", p_inv_scale_stride_z)
+
     # check if varlen
     is_varlen = layout == "thd"
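
Note on the added scale computation: the patch derives one global absolute maximum per tensor and broadcasts it into a (batch, nheads) scale tensor so the kernel can index scales per batch/head even though the value is effectively scalar. The standalone sketch below restates that idea outside the kernel wrapper; the helper name compute_global_fp8_scales and the fixed "bhsd" layout are illustrative assumptions, not part of the patch.

# Illustrative sketch (not part of the patch): global-amax FP8 scale tensors,
# mirroring the approach in the diff. Assumes q/k/v are laid out as
# (batch, nheads, seqlen, head_dim), i.e. the "bhsd" layout.
import torch

def compute_global_fp8_scales(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, eps: float = 1e-9):
    # Work in float32 so abs()/max() are well defined for FP8 inputs.
    q_max = max(q.detach().to(torch.float32).abs().max().item(), eps)
    k_max = max(k.detach().to(torch.float32).abs().max().item(), eps)
    v_max = max(v.detach().to(torch.float32).abs().max().item(), eps)

    batch, nheads_q = q.shape[0], q.shape[1]
    nheads_k = k.shape[1]

    # Every (batch, head) slot holds the same global value; stride(0) then gives
    # the kernel a per-batch stride even though the scale is effectively scalar.
    q_scale = torch.full((batch, nheads_q), q_max, dtype=torch.float32, device=q.device)
    k_scale = torch.full((batch, nheads_k), k_max, dtype=torch.float32, device=k.device)
    v_scale = torch.full((batch, nheads_k), v_max, dtype=torch.float32, device=v.device)
    return q_scale, k_scale, v_scale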