Skip to content

Commit

Permalink
global scaling
Browse files Browse the repository at this point in the history
  • Loading branch information
micmelesse committed Jan 7, 2025
1 parent a290a6d commit 3542300
Showing 1 changed file with 55 additions and 22 deletions.
77 changes: 55 additions & 22 deletions flash_attn/flash_attn_triton_amd/fwd_prefill.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import torch
import triton
import triton.language as tl
from .utils import DEBUG, DROPOUT_USE_PYTORCH, DROPOUT_DUMP, AUTOTUNE, MetaData, get_shape_from_layout, get_strides_from_layout, is_cdna, is_rdna, write_dropout_mask, create_dropout_mask
from .utils import DEBUG, DROPOUT_USE_PYTORCH, DROPOUT_DUMP, AUTOTUNE, get_shape_from_layout, get_strides_from_layout, is_cdna, is_rdna, write_dropout_mask, create_dropout_mask

# NOTE: Triton fails to import tl.constexpr values from other modules, so re-create them at module level in this file
tl_DROPOUT_USE_PYTORCH: tl.constexpr = DROPOUT_USE_PYTORCH
Expand Down Expand Up @@ -559,34 +559,13 @@ def attention_prefill_forward_triton_impl(
return_softmax,
use_exp2):


if q.dtype in FP8_TYPES:
is_fp8 = True
q_scale = fp8_metadata.q_scale
k_scale = fp8_metadata.k_scale
v_scale = fp8_metadata.v_scale
p_scale = fp8_metadata.p_scale
p_inv_scale = fp8_metadata.p_inv_scale
q_scale_stride_z = q_scale.stride(0)
kv_scale_stride_z = k_scale.stride(0)
p_scale_stride_z = p_scale.stride(0)
p_inv_scale_stride_z = p_inv_scale.stride(0)
else:
q_scale = k_scale = v_scale = p_scale = p_inv_scale = 1
q_scale_stride_z = kv_scale_stride_z = p_scale_stride_z = p_inv_scale_stride_z = 0

if DEBUG:
print()
print("attention_prefill_forward_triton_impl")
print("q:", q, q.shape)
print("k:", k, k.shape)
print("v:", v, v.shape)
print("o:", o, o.shape)
print("q_scale:", q_scale)
print("k_scale:", k_scale)
print("v_scale:", v_scale)
print("p_scale:", p_scale)
print("p_inv_scale:", p_inv_scale)
print("sm_scale:", sm_scale)
print("alibi_slopes:", alibi_slopes)
print("causal:", causal)
Expand All @@ -602,6 +581,60 @@ def attention_prefill_forward_triton_impl(
print("return_scores:", return_softmax)
print("use_exp2:", use_exp2)

# Define FP8 types we support
FP8_TYPES = {torch.float8_e4m3fnuz, torch.float8_e4m3fn, torch.float8_e5m2, torch.float8_e5m2fnuz}

# Treat the inputs as FP8 when q's dtype is a supported FP8 type (only q is checked, not k or v)
is_fp8 = q.dtype in FP8_TYPES

if is_fp8:
# Convert to float32 for scale computation
q_float32 = q.detach().to(torch.float32)
k_float32 = k.detach().to(torch.float32)
v_float32 = v.detach().to(torch.float32)

# Get shapes for scaling
batch = q.size(0) if layout != "thd" else len(cu_seqlens_q) - 1
nheads_q = q.size(1) if layout == "bhsd" else q.size(2)
nheads_k = k.size(1) if layout == "bhsd" else k.size(2)

# Compute the global max absolute value of each tensor, floored at eps
eps = 1e-9
q_max = max(q_float32.abs().max().item(), eps)
k_max = max(k_float32.abs().max().item(), eps)
v_max = max(v_float32.abs().max().item(), eps)

# Create scale tensors with the global values
q_scale = torch.full((batch, nheads_q), q_max, dtype=torch.float32, device=q.device)
k_scale = torch.full((batch, nheads_k), k_max, dtype=torch.float32, device=k.device)
v_scale = torch.full((batch, nheads_k), v_max, dtype=torch.float32, device=v.device)

# Simple p_scale for softmax computation
p_scale = torch.full((batch, nheads_q), 1.0, dtype=torch.float32, device=q.device)
p_inv_scale = torch.full((batch, nheads_q), 1.0, dtype=torch.float32, device=q.device)

# Get strides for the kernel
q_scale_stride_z = q_scale.stride(0)
kv_scale_stride_z = k_scale.stride(0)
p_scale_stride_z = p_scale.stride(0)
p_inv_scale_stride_z = p_inv_scale.stride(0)
else:
# For non-FP8 types, use dummy values (no scaling needed)
q_scale = k_scale = v_scale = p_scale = p_inv_scale = 1
q_scale_stride_z = kv_scale_stride_z = p_scale_stride_z = p_inv_scale_stride_z = 0

if DEBUG:
print("is_fp8:", is_fp8)
print("q_scale:", q_scale)
print("k_scale:", k_scale)
print("v_scale:", v_scale)
print("p_scale:", p_scale)
print("p_inv_scale:", p_inv_scale)
print("q_scale_stride_z:", q_scale_stride_z)
print("kv_scale_stride_z:", kv_scale_stride_z)
print("p_scale_stride_z:", p_scale_stride_z)
print("p_inv_scale_stride_z:", p_inv_scale_stride_z)

# check if varlen
is_varlen = layout == "thd"

Expand Down

0 comments on commit 3542300

Please sign in to comment.