diff --git a/csrc/activation_kernels.cu b/csrc/activation_kernels.cu index 88275dbdd83a1..9609bcb5e2b7b 100644 --- a/csrc/activation_kernels.cu +++ b/csrc/activation_kernels.cu @@ -7,6 +7,10 @@ #include "cuda_compat.h" #include "dispatch_utils.h" +#ifdef USE_ROCM + #include "quantization/fp8/amd/hip_float8.h" +#endif + namespace vllm { template -__device__ __forceinline__ T load(T* addr) { - return addr[0]; +__device__ __forceinline__ T loadnt(T* addr) { + return __builtin_nontemporal_load(addr); } -template -__device__ __forceinline__ void store(T value, T* addr) { - addr[0] = value; +__device__ __forceinline__ _B16x8 load_ntmprl_16Byte(const _B16x8* addr) { + auto addr_alias = reinterpret_cast(addr); + auto dat0 = loadnt(addr_alias); + auto dat1 = loadnt(addr_alias + 1); + auto dat2 = loadnt(addr_alias + 2); + auto dat3 = loadnt(addr_alias + 3); + auto res = make_float4(dat0, dat1, dat2, dat3); + return *reinterpret_cast<_B16x8*>(&res); } +/////////////////////////////////// template -__device__ __forceinline__ floatx4 gcn_mfma_instr(const _B16x4& inpA, - const _B16x4& inpB, - const floatx4& inpC) { +__device__ __forceinline__ floatx4 gcn_mfma4x4x4_instr(const _B16x4& inpA, + const _B16x4& inpB, + const floatx4& inpC) { if constexpr (std::is_same::value) { return __builtin_amdgcn_mfma_f32_4x4x4f16(inpA, inpB, inpC, absz, cbid, blgp); @@ -91,6 +105,21 @@ __device__ __forceinline__ floatx4 gcn_mfma_instr(const _B16x4& inpA, } } +template +__device__ __forceinline__ floatx4 gcn_mfma16x16x16_instr(const _B16x4& inpA, + const _B16x4& inpB, + const floatx4& inpC) { + if constexpr (std::is_same::value) { + return __builtin_amdgcn_mfma_f32_16x16x16f16(inpA, inpB, inpC, absz, cbid, + blgp); + } else if constexpr (std::is_same::value) { + return __builtin_amdgcn_mfma_f32_16x16x16bf16_1k(inpA, inpB, inpC, absz, + cbid, blgp); + } else { + static_assert(false, "unsupported 16b dtype"); + } +} + template __device__ __forceinline__ float to_float(const T& inp) { if constexpr (std::is_same::value) { @@ -102,6 +131,23 @@ __device__ __forceinline__ float to_float(const T& inp) { } } +template +__device__ __forceinline__ float to_float_b16(const bit16_t& inp) { + union tmpcvt { + bit16_t u; + _Float16 f; + __hip_bfloat16 b; + } t16; + t16.u = inp; + if constexpr (std::is_same::value) { + return (float)t16.f; + } else if constexpr (std::is_same::value) { + return __bfloat162float(t16.b); + } else { + static_assert(false, "unsupported 16b dtype"); + } +} + template __device__ __forceinline__ T from_float(const float& inp) { if constexpr (std::is_same::value) { @@ -122,17 +168,22 @@ __device__ __forceinline__ _B16x4 from_floatx4(const floatx4& inp) { } t16; _B16x4 ret; if constexpr (std::is_same::value) { - #pragma unroll - for (int i = 0; i < 4; i++) { - t16.f = (_Float16)inp[i]; - ret[i] = t16.u; - } - return ret; + union h2cvt { + __half2 h2[2]; + _B16x4 b16x4; + } u; + u.h2[0] = __float22half2_rn(make_float2(inp[0], inp[1])); + u.h2[1] = __float22half2_rn(make_float2(inp[2], inp[3])); + return u.b16x4; } else if constexpr (std::is_same::value) { - #pragma unroll for (int i = 0; i < 4; i++) { - t16.b = __float2bfloat16(inp[i]); - ret[i] = t16.u; + union fcvt { + uint32_t u32; + float f32; + } u; + u.f32 = inp[i]; + u.u32 += 0x7fff + ((u.u32 >> 16) & 1); // BF16 RNE with no nan/inf check + ret[i] = uint16_t(u.u32 >> 16); } return ret; } else { @@ -150,21 +201,25 @@ __device__ __forceinline__ _B16x4 addx4(const _B16x4& inp1, } t1, t2, res; _B16x4 ret; if constexpr (std::is_same::value) { - #pragma unroll - 
for (int i = 0; i < 4; i++) { - t1.u = inp1[i]; - t2.u = inp2[i]; - res.f = t1.f + t2.f; - ret[i] = res.u; - } - return ret; + union h2cvt { + _B16x4 b16x4; + __half2 h2[2]; + } u1, u2, s; + u1.b16x4 = inp1; + u2.b16x4 = inp2; + s.h2[0] = u1.h2[0] + u2.h2[0]; + s.h2[1] = u1.h2[1] + u2.h2[1]; + return s.b16x4; } else if constexpr (std::is_same::value) { - #pragma unroll for (int i = 0; i < 4; i++) { - t1.u = inp1[i]; - t2.u = inp2[i]; - res.b = t1.b + t2.b; - ret[i] = res.u; + union fcvt { + float f32; + uint32_t i32; + } u1, u2, s; + u1.i32 = uint32_t(inp1[i]) << 16; + u2.i32 = uint32_t(inp2[i]) << 16; + s.f32 = u1.f32 + u2.f32; + ret[i] = uint16_t(s.i32 >> 16); } return ret; } else { @@ -192,15 +247,599 @@ __device__ __forceinline__ _B16x8 scaled_convert_b8x8(const _B8x8 input, } } +template +__device__ __forceinline__ _B16x8 +scaled_convert_b8x8_custom(const _B8x8 input, const float scale) { + union { + floatx4 f32x4[2]; + vllm::Float8_ f32x8; + } tmpf8; + tmpf8.f32x8 = vllm::fp8::vec_conversion( + *reinterpret_cast(&input)); + + tmpf8.f32x4[0] *= scale; + tmpf8.f32x4[1] *= scale; + + _B16x8 ret; + ret.xy[0] = from_floatx4(tmpf8.f32x4[0]); + ret.xy[1] = from_floatx4(tmpf8.f32x4[1]); + return ret; +} + +__device__ __forceinline__ floatx4 to_float_fp8x4(const _B8x4& inp) { + #if defined(__gfx90a__) + float4 f32x4 = vllm::fp8::vec_conversion( + *reinterpret_cast(&inp)); + return *reinterpret_cast(&f32x4); + #else // MI3xx+ optimized builtins + const auto f0 = __builtin_amdgcn_cvt_pk_f32_fp8(inp, false); + const auto f1 = __builtin_amdgcn_cvt_pk_f32_fp8(inp, true); + floatx4 ret; + ret[0] = f0[0]; + ret[1] = f0[1]; + ret[2] = f1[0]; + ret[3] = f1[1]; + return ret; + #endif +} + +template +__device__ __forceinline__ _B16x4 from_floatx4_rtz(const floatx4& inp) { + _B16x4 ret; + if constexpr (std::is_same::value) { + union h2cvt { + _Half2 h2[2]; + _B16x4 b16x4; + } u; + u.h2[0] = __builtin_amdgcn_cvt_pkrtz(inp[0], inp[1]); + u.h2[1] = __builtin_amdgcn_cvt_pkrtz(inp[2], inp[3]); + return u.b16x4; + } else if constexpr (std::is_same::value) { + for (int i = 0; i < 4; i++) { + union fcvt { + uint32_t i32; + float f32; + } u; + u.f32 = inp[i]; + ret[i] = uint16_t(u.i32 >> 16); + } + return ret; + } else { + static_assert(false, "unsupported 16b dtype"); + } +} + +template +__device__ __forceinline__ _B16x8 convert_b8x8_custom(const _B8x8 input) { + union { + _B8x8 b8x8; + _B8x4 b8x4[2]; + } tmp; + tmp.b8x8 = input; + _B16x8 ret; + for (int i = 0; i < 2; i++) { + ret.xy[i] = from_floatx4_rtz(to_float_fp8x4(tmp.b8x4[i])); + } + return ret; +} + /////////////////////////////////////// +// grid (num_seqs, num_partitions,num_kv_heads) +// block (256) +template +__global__ +__launch_bounds__(NUM_THREADS, 5) void paged_attention_ll4mi_QKV_mfma16_kernel( + const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] + const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, + // head_size/x, block_size, x] + const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, + // head_size, block_size] + const int num_kv_heads, const float scale, + const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] + const int* __restrict__ context_lens, // [num_seqs] + const int max_num_blocks_per_seq, + const float* __restrict__ alibi_slopes, // [num_heads] + const int q_stride, const int kv_block_stride, const int kv_head_stride, + float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] + float* __restrict__ max_logits, // [num_seqs, num_heads, + // 
max_num_partitions] + scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions, + // head_size] + OUTT* __restrict__ final_out, // [num_seqs, num_heads, head_size] + int max_ctx_blocks, const float* k_scale_ptr, const float* v_scale_ptr) { + constexpr int NWARPS = NUM_THREADS / WARP_SIZE; + const int warpid = threadIdx.x / WARP_SIZE; + const int laneid = threadIdx.x % WARP_SIZE; + const int lane4id = laneid % 4; + const int lane16id = laneid % 16; + const int rowid = laneid / 16; + + const int seq_idx = blockIdx.x; + const int partition_idx = blockIdx.y; + + constexpr int T_PAR_SIZE = 256; // token partition size set to 256 + + const int max_num_partitions = gridDim.y; + + const int context_len = context_lens[seq_idx]; -// grid (num_seqs, num_partitions,num_heads/gqa_ratio) -// block (partition size) + const int partition_start_token_idx = + partition_idx * T_PAR_SIZE; // partition_size; + // exit if partition is out of context for seq + if (partition_start_token_idx >= context_len) { + return; + } + + constexpr int GQA_RATIO4 = DIVIDE_ROUND_UP(GQA_RATIO, 4); + + __shared__ float shared_qk_max[NWARPS][16 + 1]; + __shared__ float shared_exp_sum[NWARPS][16 + 1]; + // shared_logits is used for multiple purposes + __shared__ _B16x4 shared_logits[NWARPS][4][16][4]; + + // for QK mfma16x16, layout is QHead/Tokenx16 across every 16 lanes, 16 Bytes + // HeadElements in each lane, 4x16B HeadElements across 4 rows of warp + constexpr int ROWS_PER_WARP = + WARP_SIZE / 16; // rows refers to 16 lanes; refer dpp terminology + constexpr int CONTIGUOUS_KV_ELEMS_16B_LOAD = + 16 / sizeof(cache_t); // 8 for 16 bit cache type, 16 for 8 bit types + constexpr int QKHE_PER_FETCH = + CONTIGUOUS_KV_ELEMS_16B_LOAD * + ROWS_PER_WARP; // each fetch across a warp fetches these many elements + constexpr int QK_SIZE_RATIO = + sizeof(scalar_t) / + sizeof(cache_t); // 1 for 16bit types, 2 for 8bit types + constexpr int QKHELOOP = HEAD_SIZE / QKHE_PER_FETCH; // 4xQKHE_16B across + // warp + + _B16x8 Qlocal[QKHELOOP] + [QK_SIZE_RATIO]; // note that 16 contiguous elements of Q should + // be fetched per lane for 8 bit cache types : + // QK_SIZE_RATIO changes for this + + constexpr int CONTIGUOUS_SCALAR_ELEMS_16B = 16 / sizeof(scalar_t); + + constexpr int TOKENS_PER_WARP = + T_PAR_SIZE / + NWARPS; // sub partition of tokens per warp for qk calculation + constexpr int TLOOP = + TOKENS_PER_WARP / + 16; // each mfma16x16x16 instruction processes 16 tokens + + _B16x8 Klocal[TLOOP][QKHELOOP]; // can be interpreted as B8x16 for 8 bit + // types + + const int wg_start_head_idx = blockIdx.z * GQA_RATIO; + const int wg_start_kv_head_idx = blockIdx.z; + const int total_num_heads = gridDim.z * GQA_RATIO; + + // for QK mfma, tokens in multiples of TOKENS_PER_WARP are spread across warps + // each mfma takes QH16xT16x16HE across warp + // repeat mfmas across QKHELOOP dimension + // output layout from QKmfma : QH16xT4x4 16 qheads across 16 lanes, 16 tokens + // across 4 rows x 4 tokens per lane + + const int num_context_blocks = DIVIDE_ROUND_UP(context_len, BLOCK_SIZE); + const int last_ctx_block = num_context_blocks - 1; + + const int* block_table_seq = block_tables + seq_idx * max_num_blocks_per_seq; + + int kphysical_block_number[TLOOP]; + + // fetch k physical block numbers + for (int token_depth = 0; token_depth < TLOOP; token_depth++) { + const int klocal_token_idx = + TOKENS_PER_WARP * warpid + token_depth * 16 + lane16id; + const int kglobal_token_idx = partition_start_token_idx + klocal_token_idx; + const int 
kblock_idx = (kglobal_token_idx < context_len) + ? kglobal_token_idx / BLOCK_SIZE + : last_ctx_block; + kphysical_block_number[token_depth] = block_table_seq[kblock_idx]; + } + + // fetch Q in shared across warps and then write to registers + const int local_qhead_idx = 4 * warpid + rowid; + const int global_qhead_idx = wg_start_head_idx + local_qhead_idx; + const int64_t seq_idx64 = static_cast(seq_idx); + const scalar_t* q_ptr = + q + seq_idx64 * q_stride + global_qhead_idx * HEAD_SIZE; + + const int qhead_element = lane16id * CONTIGUOUS_SCALAR_ELEMS_16B; + if ((local_qhead_idx < GQA_RATIO) && (qhead_element < HEAD_SIZE)) { + const scalar_t* q_fetch_ptr = q_ptr + qhead_element; + const _B16x8* q_fetch_ptr_16B = + reinterpret_cast(q_fetch_ptr); + _B16x8 tmp = *q_fetch_ptr_16B; + if constexpr (KV_DTYPE == vllm::Fp8KVCacheDataType::kAuto) { + const int offset1 = + lane16id / + 4; // 16 contiguous chunks of head elems are spread across 4x4lanes + shared_logits[offset1][lane4id][local_qhead_idx][0] = tmp.xy[0]; + shared_logits[offset1][lane4id][local_qhead_idx][1] = tmp.xy[1]; + } else { + for (int i = 0; i < 2; i++) { + const int head_elem = lane16id * 2 + i; // element id in _B16x4 terms + const int offset3 = head_elem % 4; + const int offset2 = (head_elem / 4) % 4; + const int offset1 = head_elem / 4 / 4; + shared_logits[offset1][offset2][local_qhead_idx][offset3] = tmp.xy[i]; + } + } + } + __syncthreads(); + for (int qkhe_depth = 0; qkhe_depth < QKHELOOP; qkhe_depth++) { + for (int qkratio = 0; qkratio < QK_SIZE_RATIO; qkratio++) { + for (int i = 0; i < 2; i++) { + Qlocal[qkhe_depth][qkratio].xy[i] = + shared_logits[qkhe_depth][rowid][lane16id % GQA_RATIO] + [2 * qkratio + i]; + } + } + } + + // set to true to enable non temporal kv loads: has some benefit in very high + // batch size cases + constexpr bool NT_KV_LOAD = false; + + constexpr int KX = + 16 / sizeof(cache_t); // vLLM defines x as 16 Bytes of kv cache elements + const cache_t* k_ptr = k_cache + wg_start_kv_head_idx * kv_head_stride; + + const int row_head_elem = rowid * CONTIGUOUS_KV_ELEMS_16B_LOAD; + // fetch K values + for (int token_depth = 0; token_depth < TLOOP; token_depth++) { + const int64_t kblock_number = + static_cast(kphysical_block_number[token_depth]); + const cache_t* k_ptr2 = k_ptr + kblock_number * kv_block_stride; + const int klocal_token_idx = + TOKENS_PER_WARP * warpid + token_depth * 16 + lane16id; + const int kglobal_token_idx = partition_start_token_idx + klocal_token_idx; + const int kphysical_block_offset = klocal_token_idx % BLOCK_SIZE; + const cache_t* k_ptr3 = k_ptr2 + kphysical_block_offset * KX; + + for (int qkhe_depth = 0; qkhe_depth < QKHELOOP; qkhe_depth++) { + const int head_elem = row_head_elem + qkhe_depth * QKHE_PER_FETCH; + const int offset1 = head_elem / KX; + const int offset2 = head_elem % KX; + const cache_t* k_fetch_ptr = k_ptr3 + offset1 * BLOCK_SIZE * KX + offset2; + const _B16x8* k_fetch_ptr_16B = + reinterpret_cast(k_fetch_ptr); + if constexpr (NT_KV_LOAD) { + Klocal[token_depth][qkhe_depth] = load_ntmprl_16Byte(k_fetch_ptr_16B); + } else { + Klocal[token_depth][qkhe_depth] = *k_fetch_ptr_16B; + } + } + } + + float alibi_slope; + if constexpr (ALIBI_ENABLED) { + const int alibi_head_idx = wg_start_head_idx + lane16id; + alibi_slope = (lane16id < GQA_RATIO) ? 
alibi_slopes[alibi_head_idx] : 0.f; + } + + constexpr int VTOKENS_PER_LANE = + TOKENS_PER_WARP / ROWS_PER_WARP; // 64/4 = 16 contiguous vtokens per lane + constexpr int VBLOCKS_PER_LANE = + 1; // assumes block size >=16, each lane can correspond to 1 block only + constexpr int VTLOOP = NWARPS; // corresponds to tokens across warps + constexpr int VTLANELOOP = DIVIDE_ROUND_UP( + VTOKENS_PER_LANE, + CONTIGUOUS_KV_ELEMS_16B_LOAD); // optimized for 16B fetches; assumes + // minimum block size is 16 + constexpr int VHELOOP = + HEAD_SIZE / 16 / NWARPS; // head_size distributed across warps; each mfma + // instr works on 16 head elements + + int vphysical_block_number[VTLOOP][VBLOCKS_PER_LANE]; + + // fetch v physical block numbers + for (int vtoken_depth = 0; vtoken_depth < VTLOOP; vtoken_depth++) { + for (int vblock_depth = 0; vblock_depth < VBLOCKS_PER_LANE; + vblock_depth++) { + const int vlocal_token_idx = + vtoken_depth * VTOKENS_PER_LANE * ROWS_PER_WARP + + rowid * VTOKENS_PER_LANE + vblock_depth * BLOCK_SIZE; + const int vglobal_token_idx = + partition_start_token_idx + vlocal_token_idx; + const int vblock_idx = (vglobal_token_idx < context_len) + ? vglobal_token_idx / BLOCK_SIZE + : last_ctx_block; + vphysical_block_number[vtoken_depth][vblock_depth] = + block_table_seq[vblock_idx]; + } + } + + _B16x8 Vlocal[VTLOOP][VHELOOP] + [VTLANELOOP]; // this can be interpreted as B8x16 too + + const cache_t* v_ptr = v_cache + wg_start_kv_head_idx * kv_head_stride + + ((rowid * VTOKENS_PER_LANE) % BLOCK_SIZE); + + // v fetches are 16head elems across lanes x 16 tokens per lane + for (int vhe_depth = 0; vhe_depth < VHELOOP; vhe_depth++) { + const int vhead_elem = vhe_depth * NWARPS * 16 + warpid * 16 + lane16id; + const cache_t* v_ptr2 = v_ptr + vhead_elem * BLOCK_SIZE; + + for (int vtoken_depth = 0; vtoken_depth < VTLOOP; vtoken_depth++) { + for (int vfetch_depth = 0; vfetch_depth < VTLANELOOP; vfetch_depth++) { + const int vblock_depth = 0; + const int64_t vblock_number = static_cast( + vphysical_block_number[vtoken_depth][vblock_depth]); + const cache_t* v_ptr3 = v_ptr2 + (vblock_number * kv_block_stride); + + const cache_t* v_fetch_ptr = + v_ptr3 + vfetch_depth * CONTIGUOUS_KV_ELEMS_16B_LOAD; + const _B16x8* v_fetch_ptr_16B = + reinterpret_cast(v_fetch_ptr); + if constexpr (NT_KV_LOAD) { + Vlocal[vtoken_depth][vhe_depth][vfetch_depth] = + load_ntmprl_16Byte(v_fetch_ptr_16B); + } else { + Vlocal[vtoken_depth][vhe_depth][vfetch_depth] = *v_fetch_ptr_16B; + } + } + } + } + + // calculate post qk mfma scale + float scale2 = scale; + if constexpr (KV_DTYPE != vllm::Fp8KVCacheDataType::kAuto) { + // multiply by k_scale if fp8 kv cache + scale2 *= *k_scale_ptr; + } + + floatx4 dout[TLOOP]; + // qk mfma + for (int token_depth = 0; token_depth < TLOOP; token_depth++) { + dout[token_depth] = {0}; + for (int qkhe_depth = 0; qkhe_depth < QKHELOOP; qkhe_depth++) { + if constexpr (KV_DTYPE == vllm::Fp8KVCacheDataType::kAuto) { + for (int qkratio = 0; qkratio < QK_SIZE_RATIO; qkratio++) { + for (int i = 0; i < 2; i++) { + dout[token_depth] = gcn_mfma16x16x16_instr( + Klocal[token_depth][qkhe_depth].xy[i], + Qlocal[qkhe_depth][qkratio].xy[i], dout[token_depth]); + } + } + } else { // kv cache dtype fp8 + auto Ktmp = Klocal[token_depth][qkhe_depth]; + _B8x16 Ktmp8x16 = *reinterpret_cast<_B8x16*>(&Ktmp); + for (int qkratio = 0; qkratio < QK_SIZE_RATIO; qkratio++) { + _B8x8 Ktmp8x8 = Ktmp8x16.xy[qkratio]; + _B16x8 Klocaltmp = convert_b8x8_custom(Ktmp8x8); + for (int i = 0; i < 2; i++) { + dout[token_depth] = 
gcn_mfma16x16x16_instr( + Klocaltmp.xy[i], Qlocal[qkhe_depth][qkratio].xy[i], + dout[token_depth]); + } + } + } + } + dout[token_depth] *= scale2; + } + + const int qkout_token_idx = + partition_start_token_idx + TOKENS_PER_WARP * warpid + rowid * 4; + + // apply alibi + if constexpr (ALIBI_ENABLED) { + for (int token_depth = 0; token_depth < TLOOP; token_depth++) { + const int local_token_idx = qkout_token_idx + token_depth * 16; + const int alibi_offset = local_token_idx - context_len + 1; + for (int i = 0; i < 4; i++) { + dout[token_depth][i] += alibi_slope * (alibi_offset + i); + } + } + } + + // calculate qk_max and exp_sum per warp and write to shared memory + float qk_max = -FLT_MAX; + float exp_sum = 0.0f; + + for (int token_depth = 0; token_depth < TLOOP; token_depth++) { + const int local_token_idx = qkout_token_idx + token_depth * 16; + for (int i = 0; i < 4; i++) { + const float tmp = + (local_token_idx + i < context_len) ? dout[token_depth][i] : -FLT_MAX; + qk_max = fmaxf(qk_max, tmp); + } + } + + for (int mask = WARP_SIZE / 2; mask >= 16; mask /= 2) { + qk_max = fmaxf(qk_max, __shfl_xor(qk_max, mask)); + } + + for (int token_depth = 0; token_depth < TLOOP; token_depth++) { + const int local_token_idx = qkout_token_idx + token_depth * 16; + for (int i = 0; i < 4; i++) { + const float tmp = (local_token_idx + i < context_len) + ? __expf(dout[token_depth][i] - qk_max) + : 0.0f; + dout[token_depth][i] = tmp; + exp_sum += tmp; + } + } + + for (int mask = WARP_SIZE / 2; mask >= 16; mask /= 2) { + exp_sum += __shfl_xor(exp_sum, mask); + } + + __syncthreads(); // sync before writing to shared mem + + float* shared_mem = reinterpret_cast(shared_logits); + if (laneid < 16) { + const int qk_max_offset = warpid * 16 + lane16id; + shared_mem[qk_max_offset] = qk_max; + const int exp_sum_offset = NWARPS * 16 + qk_max_offset; + shared_mem[exp_sum_offset] = exp_sum; + } + + __syncthreads(); + + // calculate partition qk_max and exp_sum + float partition_qk_max = -FLT_MAX; + float warp_qk_max_exp[NWARPS]; + float partition_exp_sum = 0.0f; + + for (int w = 0; w < NWARPS; w++) { + warp_qk_max_exp[w] = shared_mem[w * 16 + lane16id]; + partition_qk_max = fmaxf(partition_qk_max, warp_qk_max_exp[w]); + } + + for (int w = 0; w < NWARPS; w++) { + warp_qk_max_exp[w] = __expf(warp_qk_max_exp[w] - partition_qk_max); + partition_exp_sum += + shared_mem[NWARPS * 16 + w * 16 + lane16id] * warp_qk_max_exp[w]; + } + + const float inv_sum_scale = + __fdividef(1.f, partition_exp_sum + 1e-6f) * warp_qk_max_exp[warpid]; + + __syncthreads(); + + // disable rtz conversion due to its impact on accuracy. + constexpr bool LOGITS_RTZ_CONVERSION = false; + + // write logits to shared mem + for (int token_depth = 0; token_depth < TLOOP; token_depth++) { + dout[token_depth] *= inv_sum_scale; + if constexpr (LOGITS_RTZ_CONVERSION) { + // use rtz conversion for better performance, with negligible impact on + // accuracy. 
+ shared_logits[warpid][token_depth][lane16id][rowid] = + from_floatx4_rtz(dout[token_depth]); + } else { + shared_logits[warpid][token_depth][lane16id][rowid] = + from_floatx4(dout[token_depth]); + } + } + // write out partition max_logits and exp_sum + if (threadIdx.x < GQA_RATIO) { + const int qhead_idx = lane16id; + const int offset = seq_idx * total_num_heads * max_num_partitions + + (wg_start_head_idx + qhead_idx) * max_num_partitions + + partition_idx; + max_logits[offset] = partition_qk_max; + exp_sums[offset] = partition_exp_sum; + } + + __syncthreads(); + + constexpr int ELEMS8_ELEMS4_RATIO = 8 / 4; + constexpr int ELEMS16_ELEMS8_RATIO = 16 / 8; + + _B16x4 outelems[VHELOOP]; + // Softmax V mfma + // v layout: 16he across lanes x 16 tokens per lane + for (int vhe_depth = 0; vhe_depth < VHELOOP; vhe_depth++) { + floatx4 tmp_out = {0}; + + for (int vtoken_depth = 0; vtoken_depth < VTLOOP; vtoken_depth++) { + if constexpr (KV_DTYPE == vllm::Fp8KVCacheDataType::kAuto) { + for (int vfetch_depth = 0; vfetch_depth < VTLANELOOP; vfetch_depth++) { + for (int i = 0; i < ELEMS8_ELEMS4_RATIO; i++) { + const int offset = rowid * VTLANELOOP * ELEMS8_ELEMS4_RATIO + + vfetch_depth * ELEMS8_ELEMS4_RATIO + i; + const int offset1 = offset % ROWS_PER_WARP; + const int offset2 = offset / ROWS_PER_WARP; + // output format is 16 qheads across 16 lanes, 16 head elems spread + // across 4 rows + tmp_out = gcn_mfma16x16x16_instr( + Vlocal[vtoken_depth][vhe_depth][vfetch_depth].xy[i], + shared_logits[vtoken_depth][offset2][lane16id][offset1], + tmp_out); + } + } + // KV cache fp8 + } else { + for (int vfetch_depth = 0; vfetch_depth < VTLANELOOP; vfetch_depth++) { + _B16x8 Vtmp = Vlocal[vtoken_depth][vhe_depth][vfetch_depth]; + // reinterpret V format as 16 elements of 8bits + _B8x16 Vtmp8x16 = *reinterpret_cast<_B8x16*>(&Vtmp); + for (int j = 0; j < ELEMS16_ELEMS8_RATIO; j++) { + _B8x8 Vtmp8x8 = Vtmp8x16.xy[j]; + _B16x8 Vlocaltmp = convert_b8x8_custom(Vtmp8x8); + for (int i = 0; i < ELEMS8_ELEMS4_RATIO; i++) { + const int offset = + rowid * ELEMS16_ELEMS8_RATIO * ELEMS8_ELEMS4_RATIO + + j * ELEMS8_ELEMS4_RATIO + i; + const int offset1 = offset % ROWS_PER_WARP; + const int offset2 = offset / ROWS_PER_WARP; + // output format is 16 qheads across 16 lanes, 16 head elems + // spread across 4 rows + tmp_out = gcn_mfma16x16x16_instr( + Vlocaltmp.xy[i], + shared_logits[vtoken_depth][offset2][lane16id][offset1], + tmp_out); + } + } + } + } + } + // apply post Softmax V mfma v_scale + if constexpr (KV_DTYPE != vllm::Fp8KVCacheDataType::kAuto) { + tmp_out *= *v_scale_ptr; + } + outelems[vhe_depth] = from_floatx4(tmp_out); + } + + __syncthreads(); + + // store Softmax-V mfma output to shared mem + for (int vhe_depth = 0; vhe_depth < VHELOOP; vhe_depth++) { + // lane16 id head dimension; rowid head element dimension + shared_logits[warpid][vhe_depth][lane16id][rowid] = outelems[vhe_depth]; + } + + __syncthreads(); + + // write to tmp_out with coalesced writes after reading from shared mem + if (warpid == 0) { + _B16x8 vout[GQA_RATIO4]; + // each lane writes out 16Bytes of tmp_out along head elem dimension + const int head_elem_idx = lane16id * 8; + if (head_elem_idx < HEAD_SIZE) { + for (int h = 0; h < GQA_RATIO4; h++) { + const int local_head_idx = 4 * h + rowid; + const int offset1 = (head_elem_idx / 16) % 4; + const int offset2 = head_elem_idx / 16 / NWARPS; + const int offset3 = (head_elem_idx / 4) % 4; + for (int i = 0; i < 2; i++) { + vout[h].xy[i] = + shared_logits[offset1][offset2][local_head_idx][offset3 + i]; 
+ } + } + + const int hsz_maxp_mult = HEAD_SIZE * max_num_partitions; + scalar_t* out_ptr = out + seq_idx * total_num_heads * hsz_maxp_mult + + partition_idx * HEAD_SIZE; + for (int h = 0; h < GQA_RATIO4; h++) { + const int local_head_idx = 4 * h + rowid; + if (local_head_idx < GQA_RATIO) { + const int out_head_idx = wg_start_head_idx + local_head_idx; + scalar_t* out_ptr2 = out_ptr + out_head_idx * hsz_maxp_mult; + scalar_t* out_ptr3 = out_ptr2 + head_elem_idx; + _B16x8* out_ptr_B16x8 = reinterpret_cast<_B16x8*>(out_ptr3); + *out_ptr_B16x8 = vout[h]; + } + } + } + } +} + +///////////////////////////////////////////////////////////// +// grid (num_seqs, num_partitions, num_kv_heads) +// block (256 : partition size) +// each WG handles 1 partition per sequence template -__global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( +__global__ +__launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel( const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, // head_size/x, block_size, x] @@ -215,9 +854,9 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] float* __restrict__ max_logits, // [num_seqs, num_heads, // max_num_partitions] - scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions, - // head_size] - scalar_t* __restrict__ final_out, // [num_seqs, num_heads, head_size] + scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions, + // head_size] + OUTT* __restrict__ final_out, // [num_seqs, num_heads, head_size] int max_ctx_blocks, const float* k_scale_ptr, const float* v_scale_ptr) { constexpr int NWARPS = NUM_THREADS / WARP_SIZE; const int warpid = threadIdx.x / WARP_SIZE; @@ -235,27 +874,35 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( if (partition_start_token_idx >= context_len) { return; } - constexpr int QHLOOP = - DIVIDE_ROUND_UP(GQA_RATIO, 4); // each 4 lanes fetch 4 different qheads, - // total qheads =8, so qhloop is 2 + // every 4 lanes fetch 4 different qheads + // qhloop = num loops over qhead dimension + constexpr int QHLOOP = DIVIDE_ROUND_UP(GQA_RATIO, 4); constexpr int GQA_RATIO4 = 4 * QHLOOP; __shared__ float shared_qk_max[NWARPS][GQA_RATIO4 + 1]; __shared__ float shared_exp_sum[NWARPS][GQA_RATIO4 + 1]; _B16x8 Qlocal[QHLOOP]; constexpr int x = 16 / sizeof(scalar_t); + // kheloop = num loops over head_size for 16Bytes of Q/dequantized K elements constexpr int KHELOOP = HEAD_SIZE / x; _B16x8 Klocal[KHELOOP]; _B8x8 Klocalb8[KHELOOP]; - constexpr int VHELOOP = - HEAD_SIZE / - WARP_SIZE; // v head_size dimension is distributed across lanes - constexpr int VTLOOP = 8; // 16 separate 4xtokens across warp -> 16/2 - // 8xtokens + // for SoftMax-V Gemm, V head_size dimension is distributed across warp + // vheloop = num loops to cover v head size dimension + constexpr int VHELOOP = HEAD_SIZE / WARP_SIZE; + // softmax out has warp_size tokens across warp + // vtloop = num loops to cover warp_size(64) tokens with 16Bytes of + // dequantized V elements + constexpr int VTLOOP = WARP_SIZE / 8; + // num vblocks to cover warp_size(64) v elements + constexpr int VBLOCKS = 8 * VTLOOP / BLOCK_SIZE; + int vphysical_blocks[VBLOCKS]; _B16x8 Vlocal[VHELOOP][VTLOOP]; _B8x8 Vlocalb8[VHELOOP][VTLOOP]; floatx4 dout[QHLOOP]; float qk_max[QHLOOP]; - #pragma unroll + + __shared__ _B16x4 
vout_shared[QHLOOP][VHELOOP][WARP_SIZE][NWARPS + 1]; + for (int h = 0; h < QHLOOP; h++) { dout[h] = {0}; qk_max[h] = -FLT_MAX; @@ -267,37 +914,37 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( const int warp_start_token_idx = partition_start_token_idx + warpid * WARP_SIZE; - if (warp_start_token_idx >= context_len) { // warp out of context + // entire warp out of context + if (warp_start_token_idx >= context_len) { #pragma unroll for (int h = 0; h < GQA_RATIO4; h++) { shared_qk_max[warpid][h] = -FLT_MAX; shared_exp_sum[warpid][h] = 0.0f; } - } else { // warp within context - + // warp within context + } else { const int num_context_blocks = DIVIDE_ROUND_UP(context_len, BLOCK_SIZE); const int last_ctx_block = num_context_blocks - 1; const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq; - + // token id within partition const int local_token_idx = threadIdx.x; + // token id within sequence const int global_token_idx = partition_start_token_idx + local_token_idx; + // fetch block number for k const int block_idx = (global_token_idx < context_len) ? global_token_idx / BLOCK_SIZE : last_ctx_block; - // fetch block number for q and k - // int32 physical_block_number leads to overflow when multiplied with - // kv_block_stride + + // fetch k physical block number + // int32 physical_block_number leads to overflow when multiplied with + // kv_block_stride const int64_t physical_block_number = static_cast(block_table[block_idx]); // fetch vphysical block numbers up front - constexpr int VBLOCKS = 8 * VTLOOP / BLOCK_SIZE; - int vphysical_blocks[VBLOCKS]; - const int warp_start_block_idx = warp_start_token_idx / BLOCK_SIZE; - #pragma unroll for (int b = 0; b < VBLOCKS; b++) { const int vblock_idx = warp_start_block_idx + b; const int vblock_idx_ctx = @@ -305,12 +952,13 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( vphysical_blocks[b] = block_table[vblock_idx_ctx]; } - // each 4 lanes fetch 8 helems, so warp fetches 8*16 = 128 helems + // fetch q elements + // every 4 lanes fetch 8 elems, so warp fetches 8*16 = 128 elems const scalar_t* q_ptr = q + seq_idx * q_stride + wg_start_head_idx * HEAD_SIZE; const _B16x8* q_ptrh8 = reinterpret_cast(q_ptr); const int qhead_elemh8 = laneid / 4; - #pragma unroll + for (int h = 0; h < QHLOOP - 1; h++) { const int qhead_idx = h * 4 + lane4id; Qlocal[h] = q_ptrh8[qhead_idx * HEAD_SIZE / 8 + qhead_elemh8]; @@ -324,22 +972,24 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( Qlocal[QHLOOP - 1].xy[1] = {0}; } + // fetch k elements const cache_t* k_ptr = k_cache + physical_block_number * kv_block_stride + wg_start_kv_head_idx * kv_head_stride; - const int physical_block_offset = - local_token_idx % BLOCK_SIZE; // since x=half8, physical_block_offset - // is already cast as _H8 + // physical_block_offset is already cast in terms of _B16x8 + const int physical_block_offset = local_token_idx % BLOCK_SIZE; + + // each K fetch is for 8 elements of cache_t which are later dequantized to + // scalar_t for fp8 if constexpr (KV_DTYPE == vllm::Fp8KVCacheDataType::kAuto) { const _B16x8* k_ptrh8 = reinterpret_cast(k_ptr); - #pragma unroll for (int d = 0; d < KHELOOP; d++) { Klocal[d] = k_ptrh8[d * BLOCK_SIZE + physical_block_offset]; } } else { + // vllm defines X as 16 Bytes of elements of cache_t constexpr int X = 16 / sizeof(cache_t); const cache_t* k_ptr2 = k_ptr + physical_block_offset * X; - #pragma unroll for (int d = 0; d < KHELOOP; d++) { const int 
head_elem = d * 8; const int offset1 = head_elem / X; @@ -349,9 +999,9 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( } } + // optional alibi fetch float alibi_slope[QHLOOP]; - if (alibi_slopes != nullptr) { - #pragma unroll + if constexpr (ALIBI_ENABLED) { for (int h = 0; h < QHLOOP; h++) { const int qhead_idx = h * 4 + lane4id; alibi_slope[h] = (qhead_idx < GQA_RATIO) @@ -361,10 +1011,10 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( } const cache_t* v_ptr = v_cache + wg_start_kv_head_idx * kv_head_stride; + // fetch vcache in kv cache auto case if constexpr (KV_DTYPE == vllm::Fp8KVCacheDataType::kAuto) { const _B16x8* v_ptrh8 = reinterpret_cast(v_ptr); // iterate over each v block - #pragma unroll for (int b = 0; b < VBLOCKS; b++) { // int32 physical_block_number leads to overflow when multiplied with // kv_block_stride @@ -373,21 +1023,20 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( const _B16x8* v_ptrh8b = v_ptrh8 + (vphysical_block_number * kv_block_stride) / 8; // iterate over each head elem (within head_size) - #pragma unroll for (int h = 0; h < VHELOOP; h++) { const int head_size_elem = h * WARP_SIZE + laneid; const _B16x8* v_ptrh8be = v_ptrh8b + head_size_elem * BLOCK_SIZE / 8; // iterate over all velems within block - #pragma unroll for (int d = 0; d < BLOCK_SIZE / 8; d++) { Vlocal[h][b * BLOCK_SIZE / 8 + d] = v_ptrh8be[d]; } } } - } else { + } // if constexpr (KV_DTYPE == vllm::Fp8KVCacheDataType::kAuto) + // fetch vcache in fp8 case + else { // if constexpr (KV_DTYPE != vllm::Fp8KVCacheDataType::kAuto) const _B8x8* v_ptrh8 = reinterpret_cast(v_ptr); // iterate over each v block - #pragma unroll for (int b = 0; b < VBLOCKS; b++) { // int32 physical_block_number leads to overflow when multiplied with // kv_block_stride @@ -396,164 +1045,153 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( const _B8x8* v_ptrh8b = v_ptrh8 + (vphysical_block_number * kv_block_stride) / 8; // iterate over each head elem (within head_size) - #pragma unroll for (int h = 0; h < VHELOOP; h++) { const int head_size_elem = h * WARP_SIZE + laneid; const _B8x8* v_ptrh8be = v_ptrh8b + head_size_elem * BLOCK_SIZE / 8; // iterate over all velems within block - #pragma unroll for (int d = 0; d < BLOCK_SIZE / 8; d++) { - // Vlocalb8[h][b * BLOCK_SIZE / 8 + d] = v_ptrh8be[d]; - const _B8x8 Vlocalb8 = v_ptrh8be[d]; - Vlocal[h][b * BLOCK_SIZE / 8 + d] = - scaled_convert_b8x8(Vlocalb8, *v_scale_ptr); + Vlocalb8[h][b * BLOCK_SIZE / 8 + d] = v_ptrh8be[d]; } } } } + #define QK_mfma(x) \ + if constexpr (KV_DTYPE != vllm::Fp8KVCacheDataType::kAuto) { \ + Klocal[x] = convert_b8x8_custom(Klocalb8[x]); \ + } \ + for (int h = 0; h < QHLOOP; h++) { \ + dout[h] = gcn_mfma4x4x4_instr( \ + Qlocal[h].xy[0], Klocal[x].xy[0], dout[h]); \ + dout[h] = gcn_mfma4x4x4_instr( \ + Qlocal[h].xy[1], Klocal[x].xy[1], dout[h]); \ + } + // QK mfma with Q mfma block broadcast + // Q values across head_size dimension stored across lanes + // K values across head_size dimension are stored depthwise within lane + // Q broadcast with absz, cbid of mfma instruction + QK_mfma(0); + QK_mfma(1); + QK_mfma(2); + QK_mfma(3); + QK_mfma(4); + QK_mfma(5); + QK_mfma(6); + QK_mfma(7); + // below only needed for head size 128 + if constexpr (KHELOOP > 8) { + QK_mfma(8); + QK_mfma(9); + QK_mfma(10); + QK_mfma(11); + QK_mfma(12); + QK_mfma(13); + QK_mfma(14); + QK_mfma(15); + } + #undef QK_mfma + + float scale2 = 
scale; if constexpr (KV_DTYPE != vllm::Fp8KVCacheDataType::kAuto) { - #pragma unroll - for (int d = 0; d < KHELOOP; d++) { - Klocal[d] = - scaled_convert_b8x8(Klocalb8[d], *k_scale_ptr); - } + // post mfma scaling for fp8 + scale2 *= *k_scale_ptr; } - #pragma unroll for (int h = 0; h < QHLOOP; h++) { - dout[h] = gcn_mfma_instr(Qlocal[h].xy[0], - Klocal[0].xy[0], dout[h]); - dout[h] = gcn_mfma_instr(Qlocal[h].xy[1], - Klocal[0].xy[1], dout[h]); - dout[h] = gcn_mfma_instr(Qlocal[h].xy[0], - Klocal[1].xy[0], dout[h]); - dout[h] = gcn_mfma_instr(Qlocal[h].xy[1], - Klocal[1].xy[1], dout[h]); - dout[h] = gcn_mfma_instr(Qlocal[h].xy[0], - Klocal[2].xy[0], dout[h]); - dout[h] = gcn_mfma_instr(Qlocal[h].xy[1], - Klocal[2].xy[1], dout[h]); - dout[h] = gcn_mfma_instr(Qlocal[h].xy[0], - Klocal[3].xy[0], dout[h]); - dout[h] = gcn_mfma_instr(Qlocal[h].xy[1], - Klocal[3].xy[1], dout[h]); - dout[h] = gcn_mfma_instr(Qlocal[h].xy[0], - Klocal[4].xy[0], dout[h]); - dout[h] = gcn_mfma_instr(Qlocal[h].xy[1], - Klocal[4].xy[1], dout[h]); - dout[h] = gcn_mfma_instr(Qlocal[h].xy[0], - Klocal[5].xy[0], dout[h]); - dout[h] = gcn_mfma_instr(Qlocal[h].xy[1], - Klocal[5].xy[1], dout[h]); - dout[h] = gcn_mfma_instr(Qlocal[h].xy[0], - Klocal[6].xy[0], dout[h]); - dout[h] = gcn_mfma_instr(Qlocal[h].xy[1], - Klocal[6].xy[1], dout[h]); - dout[h] = gcn_mfma_instr(Qlocal[h].xy[0], - Klocal[7].xy[0], dout[h]); - dout[h] = gcn_mfma_instr(Qlocal[h].xy[1], - Klocal[7].xy[1], dout[h]); - if constexpr (KHELOOP > 8) { - dout[h] = gcn_mfma_instr(Qlocal[h].xy[0], - Klocal[8].xy[0], dout[h]); - dout[h] = gcn_mfma_instr(Qlocal[h].xy[1], - Klocal[8].xy[1], dout[h]); - dout[h] = gcn_mfma_instr(Qlocal[h].xy[0], - Klocal[9].xy[0], dout[h]); - dout[h] = gcn_mfma_instr(Qlocal[h].xy[1], - Klocal[9].xy[1], dout[h]); - dout[h] = gcn_mfma_instr(Qlocal[h].xy[0], - Klocal[10].xy[0], dout[h]); - dout[h] = gcn_mfma_instr(Qlocal[h].xy[1], - Klocal[10].xy[1], dout[h]); - dout[h] = gcn_mfma_instr(Qlocal[h].xy[0], - Klocal[11].xy[0], dout[h]); - dout[h] = gcn_mfma_instr(Qlocal[h].xy[1], - Klocal[11].xy[1], dout[h]); - dout[h] = gcn_mfma_instr(Qlocal[h].xy[0], - Klocal[12].xy[0], dout[h]); - dout[h] = gcn_mfma_instr(Qlocal[h].xy[1], - Klocal[12].xy[1], dout[h]); - dout[h] = gcn_mfma_instr(Qlocal[h].xy[0], - Klocal[13].xy[0], dout[h]); - dout[h] = gcn_mfma_instr(Qlocal[h].xy[1], - Klocal[13].xy[1], dout[h]); - dout[h] = gcn_mfma_instr(Qlocal[h].xy[0], - Klocal[14].xy[0], dout[h]); - dout[h] = gcn_mfma_instr(Qlocal[h].xy[1], - Klocal[14].xy[1], dout[h]); - dout[h] = gcn_mfma_instr(Qlocal[h].xy[0], - Klocal[15].xy[0], dout[h]); - dout[h] = gcn_mfma_instr(Qlocal[h].xy[1], - Klocal[15].xy[1], dout[h]); - } // KHELOOP>8 - dout[h] *= scale; + dout[h] *= scale2; } - // transpose dout so that 4 token ids are in each lane, and 4 heads are across - // 4 lanes - #pragma unroll + + // transpose dout so that 4 token ids are in each lane, and 4 heads are + // across 4 lanes for (int h = 0; h < QHLOOP; h++) { floatx4 tmp = {0}; - #pragma unroll for (int i = 0; i < 4; i++) { const float B = (lane4id == i) ? 1.0f : 0.0f; - // const float A = (global_token_idx < context_len) ? 
dout[h][i] : 0.0f; tmp = __builtin_amdgcn_mfma_f32_4x4x1f32(dout[h][i], B, tmp, 0, 0, 0); - // tmp = __builtin_amdgcn_mfma_f32_4x4x1f32(A, B, tmp, 0, 0, 0); } dout[h] = tmp; } const int lane4_token_idx = 4 * (global_token_idx >> 2); - const int alibi_offset = lane4_token_idx - context_len + 1; - if (alibi_slopes != nullptr) { - #pragma unroll + + if constexpr (ALIBI_ENABLED) { + const int alibi_offset = lane4_token_idx - context_len + 1; for (int h = 0; h < QHLOOP; h++) { - #pragma unroll for (int i = 0; i < 4; i++) { dout[h][i] += alibi_slope[h] * (alibi_offset + i); } } } - #pragma unroll + const int bpermute_mask = 4 * (16 * ((laneid >> 2) % 4) + lane4id); + for (int h = 0; h < QHLOOP; h++) { qk_max[h] = -FLT_MAX; - #pragma unroll for (int i = 0; i < 4; i++) { qk_max[h] = (lane4_token_idx + i < context_len) ? fmaxf(qk_max[h], dout[h][i]) : qk_max[h]; } - #pragma unroll - for (int mask = WARP_SIZE / 2; mask >= 4; mask /= 2) { - qk_max[h] = fmaxf(qk_max[h], __shfl_xor(qk_max[h], mask)); - } + + // for (int mask = WARP_SIZE / 2; mask >= 4; mask /= 2) { + // qk_max[h] = fmaxf(qk_max[h], __shfl_xor(qk_max[h], mask)); + // } + // faster version of above code with dpp + asm("v_nop\n v_nop\n v_max_f32_dpp %0, %1, %2 row_ror:4" + : "=v"(qk_max[h]) + : "v"(qk_max[h]), "v"(qk_max[h])); + asm("v_nop\n v_nop\n v_max_f32_dpp %0, %1, %2 row_ror:8" + : "=v"(qk_max[h]) + : "v"(qk_max[h]), "v"(qk_max[h])); + + auto tmp = __builtin_amdgcn_ds_bpermute( + bpermute_mask, *reinterpret_cast(&qk_max[h])); + qk_max[h] = *reinterpret_cast(&tmp); + asm("v_nop\n v_nop\n v_max_f32_dpp %0, %1, %2 row_ror:4" + : "=v"(qk_max[h]) + : "v"(qk_max[h]), "v"(qk_max[h])); + asm("v_nop\n v_nop\n v_max_f32_dpp %0, %1, %2 row_ror:8" + : "=v"(qk_max[h]) + : "v"(qk_max[h]), "v"(qk_max[h])); } float exp_sum[QHLOOP]; - #pragma unroll for (int h = 0; h < QHLOOP; h++) { exp_sum[h] = 0.0f; - #pragma unroll for (int i = 0; i < 4; i++) { dout[h][i] = (lane4_token_idx + i < context_len) ? 
__expf(dout[h][i] - qk_max[h]) : 0.0f; exp_sum[h] += dout[h][i]; } - #pragma unroll - for (int mask = WARP_SIZE / 2; mask >= 4; mask /= 2) { - exp_sum[h] += __shfl_xor(exp_sum[h], mask); - } + // for (int mask = WARP_SIZE / 2; mask >= 4; mask /= 2) { + // exp_sum[h] += __shfl_xor(exp_sum[h], mask); + // } + // faster version of above code with dpp + asm("v_nop\n v_nop\n v_add_f32_dpp %0, %1, %2 row_ror:4" + : "=v"(exp_sum[h]) + : "v"(exp_sum[h]), "v"(exp_sum[h])); + asm("v_nop\n v_nop\n v_add_f32_dpp %0, %1, %2 row_ror:8" + : "=v"(exp_sum[h]) + : "v"(exp_sum[h]), "v"(exp_sum[h])); + + auto tmp = __builtin_amdgcn_ds_bpermute( + bpermute_mask, *reinterpret_cast(&exp_sum[h])); + exp_sum[h] = *reinterpret_cast(&tmp); + asm("v_nop\n v_nop\n v_add_f32_dpp %0, %1, %2 row_ror:4" + : "=v"(exp_sum[h]) + : "v"(exp_sum[h]), "v"(exp_sum[h])); + asm("v_nop\n v_nop\n v_add_f32_dpp %0, %1, %2 row_ror:8" + : "=v"(exp_sum[h]) + : "v"(exp_sum[h]), "v"(exp_sum[h])); } - #pragma unroll - for (int h = 0; h < QHLOOP; h++) { - const int head_idx = 4 * h + lane4id; - shared_qk_max[warpid][head_idx] = qk_max[h]; - shared_exp_sum[warpid][head_idx] = exp_sum[h]; + if (laneid < 4) { + for (int h = 0; h < QHLOOP; h++) { + const int head_idx = 4 * h + lane4id; + shared_qk_max[warpid][head_idx] = qk_max[h]; + shared_exp_sum[warpid][head_idx] = exp_sum[h]; + } } } // warp within context @@ -564,18 +1202,16 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( max_logits + seq_idx * num_heads * max_num_partitions + partition_idx; float* exp_sums_ptr = exp_sums + seq_idx * num_heads * max_num_partitions + partition_idx; - #pragma unroll + // calculate qk_max and exp_sums for partition for (int h = 0; h < QHLOOP; h++) { float global_qk_max = -FLT_MAX; float warp_qk_max[NWARPS]; const int head_idx = 4 * h + lane4id; - #pragma unroll for (int w = 0; w < NWARPS; w++) { warp_qk_max[w] = shared_qk_max[w][head_idx]; global_qk_max = fmaxf(global_qk_max, warp_qk_max[w]); } float global_exp_sum = 0.0f; - #pragma unroll for (int w = 0; w < NWARPS; w++) { global_exp_sum += shared_exp_sum[w][head_idx] * __expf(warp_qk_max[w] - global_qk_max); @@ -590,99 +1226,91 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( __expf(qk_max[h] - global_qk_max); dout[h] *= global_inv_sum_scale; } + constexpr bool LOGITS_RTZ_CONVERSION = false; // logits[h] -> every 4 lanes hold 4 heads, each lane holds 4 tokens, there // are 4x16 tokens across warp _B16x4 logits[QHLOOP]; - #pragma unroll for (int h = 0; h < QHLOOP; h++) { - logits[h] = from_floatx4(dout[h]); + if constexpr (LOGITS_RTZ_CONVERSION) { + // use rtz for faster performance with no perceivable accuracy loss + logits[h] = from_floatx4_rtz(dout[h]); + } else { + logits[h] = from_floatx4(dout[h]); + } } - __shared__ _B16x4 vout_shared[QHLOOP][VHELOOP][WARP_SIZE][NWARPS + 1]; - if (warp_start_token_idx >= context_len) { // warp out of context - #pragma unroll for (int qh = 0; qh < QHLOOP; qh++) { - #pragma unroll for (int vh = 0; vh < VHELOOP; vh++) { vout_shared[qh][vh][laneid][warpid] = {0}; } } } else { // warp in context - // iterate across heads - #pragma unroll - for (int qh = 0; qh < QHLOOP; qh++) { - // iterate over each v head elem (within head_size) - #pragma unroll - for (int vh = 0; vh < VHELOOP; vh++) { - floatx4 acc = {0}; - // iterate over tokens - acc = gcn_mfma_instr(logits[qh], Vlocal[vh][0].xy[0], - acc); - acc = gcn_mfma_instr(logits[qh], Vlocal[vh][0].xy[1], - acc); - acc = gcn_mfma_instr(logits[qh], 
Vlocal[vh][1].xy[0], - acc); - acc = gcn_mfma_instr(logits[qh], Vlocal[vh][1].xy[1], - acc); - acc = gcn_mfma_instr(logits[qh], Vlocal[vh][2].xy[0], - acc); - acc = gcn_mfma_instr(logits[qh], Vlocal[vh][2].xy[1], - acc); - acc = gcn_mfma_instr(logits[qh], Vlocal[vh][3].xy[0], - acc); - acc = gcn_mfma_instr(logits[qh], Vlocal[vh][3].xy[1], - acc); - acc = gcn_mfma_instr(logits[qh], Vlocal[vh][4].xy[0], - acc); - acc = gcn_mfma_instr(logits[qh], Vlocal[vh][4].xy[1], - acc); - acc = gcn_mfma_instr(logits[qh], - Vlocal[vh][5].xy[0], acc); - acc = gcn_mfma_instr(logits[qh], - Vlocal[vh][5].xy[1], acc); - acc = gcn_mfma_instr(logits[qh], - Vlocal[vh][6].xy[0], acc); - acc = gcn_mfma_instr(logits[qh], - Vlocal[vh][6].xy[1], acc); - acc = gcn_mfma_instr(logits[qh], - Vlocal[vh][7].xy[0], acc); - acc = gcn_mfma_instr(logits[qh], - Vlocal[vh][7].xy[1], acc); - vout_shared[qh][vh][laneid][warpid] = from_floatx4(acc); + #define SV_mfma(x) \ + if constexpr (KV_DTYPE != vllm::Fp8KVCacheDataType::kAuto) { \ + Vlocal[vh][x] = convert_b8x8_custom(Vlocalb8[vh][x]); \ + } \ + for (int qh = 0; qh < QHLOOP; qh++) { \ + acc[qh] = gcn_mfma4x4x4_instr( \ + logits[qh], Vlocal[vh][x].xy[0], acc[qh]); \ + acc[qh] = gcn_mfma4x4x4_instr( \ + logits[qh], Vlocal[vh][x].xy[1], acc[qh]); \ + } + + for (int vh = 0; vh < VHELOOP; vh++) { + floatx4 acc[QHLOOP]; + for (int qh = 0; qh < QHLOOP; qh++) { + acc[qh] = {0}; + } + // SoftMax-V calculation + // logits -> token dimension is distributed across lanes + // Vlocal -> token dimension is depthwise within lane + // uses mfma instruction block broadcast for logits + SV_mfma(0); + SV_mfma(1); + SV_mfma(2); + SV_mfma(3); + SV_mfma(4); + SV_mfma(5); + SV_mfma(6); + SV_mfma(7); + + for (int qh = 0; qh < QHLOOP; qh++) { + if constexpr (KV_DTYPE != vllm::Fp8KVCacheDataType::kAuto) { + // post mfma v scale for fp8 + acc[qh] *= *v_scale_ptr; + } + vout_shared[qh][vh][laneid][warpid] = from_floatx4(acc[qh]); } } + + #undef SV_mfma } // warp in context __syncthreads(); + // final write to tmp_out after vout accumulation if (warpid == 0) { _B16x4 vout[QHLOOP][VHELOOP]; // iterate across heads - scalar_t* out_ptr; - int out_num_partitions; - if (context_len > partition_size) { - out_num_partitions = max_num_partitions; - out_ptr = out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE + - partition_idx * HEAD_SIZE; - } else { - out_num_partitions = 1; - out_ptr = final_out + seq_idx * num_heads * HEAD_SIZE; - } - #pragma unroll for (int qh = 0; qh < QHLOOP; qh++) { - // iterate over each v head elem (within head_size) - #pragma unroll + // iterate over each v head elem (within head_size) for (int vh = 0; vh < VHELOOP; vh++) { vout[qh][vh] = {0}; - #pragma unroll for (int w = 0; w < NWARPS; w++) { vout[qh][vh] = addx4(vout[qh][vh], vout_shared[qh][vh][laneid][w]); } + } + } + scalar_t* out_ptr = out + + seq_idx * num_heads * max_num_partitions * HEAD_SIZE + + partition_idx * HEAD_SIZE; + const int out_num_partitions = max_num_partitions; + bit16_t* out_ptr_b16 = reinterpret_cast(out_ptr); + for (int qh = 0; qh < QHLOOP; qh++) { + for (int vh = 0; vh < VHELOOP; vh++) { const int head_size_elem = vh * WARP_SIZE + laneid; - bit16_t* out_ptr_b16 = reinterpret_cast(out_ptr); - #pragma unroll for (int i = 0; i < 4; i++) { const int head_idx = 4 * qh + i; if (head_idx < GQA_RATIO) { @@ -693,15 +1321,15 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( } } } - } + } // warpid == 0 } // Grid: (num_heads, num_seqs). 
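For reference, the reduction kernel declared next recombines the per-partition softmax statistics (max_logits, exp_sums) and partial outputs (tmp_out) produced by the QKV kernels above. A minimal host-side C++ sketch of that recombination follows; the struct and function names are illustrative only and are not part of this patch.

```cpp
// Stand-alone sketch of the cross-partition softmax recombination performed by
// paged_attention_ll4mi_reduce_kernel. Names (Partition, reduce_partitions)
// are illustrative, not from the patch.
#include <cmath>
#include <cstddef>
#include <vector>

struct Partition {
  float max_logit;             // per-partition max of the qk logits
  float exp_sum;               // sum of exp(logit - max_logit) within the partition
  std::vector<float> tmp_out;  // partial output, weighted by the partition-local softmax
};

std::vector<float> reduce_partitions(const std::vector<Partition>& parts) {
  // global max across partitions
  float global_max = -INFINITY;
  for (const auto& p : parts) global_max = std::fmax(global_max, p.max_logit);

  // rescale each partition's exp_sum into the global frame
  float global_exp_sum = 0.f;
  for (const auto& p : parts)
    global_exp_sum += p.exp_sum * std::exp(p.max_logit - global_max);

  // accumulate rescaled partial outputs, then normalize
  const std::size_t head_size = parts[0].tmp_out.size();
  std::vector<float> out(head_size, 0.f);
  for (const auto& p : parts) {
    const float rescale = std::exp(p.max_logit - global_max);
    for (std::size_t d = 0; d < head_size; d++) out[d] += p.tmp_out[d] * rescale;
  }
  const float inv = 1.f / (global_exp_sum + 1e-6f);
  for (float& v : out) v *= inv;
  return out;
}
```

The `1e-6f` term mirrors the kernel's `__fdividef(1.0f, shared_global_exp_sum + 1e-6f)` guard against a zero exponent sum.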
-template +template __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel( - scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] + OUTT* __restrict__ out, // [num_seqs, num_heads, head_size] const float* __restrict__ exp_sums, // [num_seqs, num_heads, // max_num_partitions] const float* __restrict__ max_logits, // [num_seqs, num_heads, @@ -715,18 +1343,13 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel( const int seq_idx = blockIdx.y; const int context_len = context_lens[seq_idx]; const int num_partitions = DIVIDE_ROUND_UP(context_len, PARTITION_SIZE); - if (num_partitions == 1) { - // if num_partitions==1, main kernel will write to out directly, no work in - // reduction kernel - return; - } - constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; const int warpid = threadIdx.x / WARP_SIZE; const int laneid = threadIdx.x % WARP_SIZE; __shared__ float shared_global_exp_sum; - __shared__ float shared_exp_sums[2 * WARP_SIZE]; + // max num partitions supported is warp_size * NPAR_LOOPS + __shared__ float shared_exp_sums[NPAR_LOOPS * WARP_SIZE]; if (warpid == 0) { const float* max_logits_ptr = max_logits + @@ -735,14 +1358,25 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel( // valid partition is the last valid partition in case threadid > num // partitions - const int valid_partition = - (threadIdx.x < num_partitions) ? threadIdx.x : num_partitions - 1; - const int valid_partition2 = (WARP_SIZE + threadIdx.x < num_partitions) - ? WARP_SIZE + threadIdx.x - : num_partitions - 1; - float reg_max_logit = max_logits_ptr[valid_partition]; - float reg_max_logit2 = max_logits_ptr[valid_partition2]; - float max_logit = fmaxf(reg_max_logit, reg_max_logit2); + int valid_partition[NPAR_LOOPS]; + float reg_max_logit[NPAR_LOOPS]; + const int last_valid_partition = num_partitions - 1; + + #pragma unroll + for (int i = 0; i < NPAR_LOOPS; i++) { + const int partition_no = i * WARP_SIZE + threadIdx.x; + valid_partition[i] = + (partition_no < num_partitions) ? partition_no : last_valid_partition; + } + #pragma unroll + for (int i = 0; i < NPAR_LOOPS; i++) { + reg_max_logit[i] = max_logits_ptr[valid_partition[i]]; + } + float max_logit = reg_max_logit[0]; + #pragma unroll + for (int i = 1; i < NPAR_LOOPS; i++) { + max_logit = fmaxf(max_logit, reg_max_logit[i]); + } #pragma unroll for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) { @@ -753,17 +1387,28 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel( seq_idx * num_heads * max_num_partitions + head_idx * max_num_partitions; - float global_exp_sum = 0.0f; - float rescaled_exp_sum = exp_sums_ptr[valid_partition]; - float rescaled_exp_sum2 = exp_sums_ptr[valid_partition2]; - rescaled_exp_sum *= - (threadIdx.x < num_partitions) ? expf(reg_max_logit - max_logit) : 0.0f; - rescaled_exp_sum2 *= (threadIdx.x + WARP_SIZE < num_partitions) - ? expf(reg_max_logit2 - max_logit) - : 0.0f; - global_exp_sum += rescaled_exp_sum + rescaled_exp_sum2; - shared_exp_sums[threadIdx.x] = rescaled_exp_sum; - shared_exp_sums[threadIdx.x + WARP_SIZE] = rescaled_exp_sum2; + float rescaled_exp_sum[NPAR_LOOPS]; + #pragma unroll + for (int i = 0; i < NPAR_LOOPS; i++) { + rescaled_exp_sum[i] = exp_sums_ptr[valid_partition[i]]; + } + #pragma unroll + for (int i = 0; i < NPAR_LOOPS; i++) { + const int partition_no = i * WARP_SIZE + threadIdx.x; + rescaled_exp_sum[i] *= (partition_no < num_partitions) + ? 
expf(reg_max_logit[i] - max_logit) + : 0.0f; + } + float global_exp_sum = rescaled_exp_sum[0]; + #pragma unroll + for (int i = 1; i < NPAR_LOOPS; i++) { + global_exp_sum += rescaled_exp_sum[i]; + } + #pragma unroll + for (int i = 0; i < NPAR_LOOPS; i++) { + const int partition_no = i * WARP_SIZE + threadIdx.x; + shared_exp_sums[partition_no] = rescaled_exp_sum[i]; + } #pragma unroll for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) { @@ -840,39 +1485,72 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel( } } - if (num_partitions > MAX_NPAR) { - idx = 0; + for (int p = 1; p < NPAR_LOOPS; p++) { + if (num_partitions > p * MAX_NPAR) { + idx = 0; #pragma unroll - for (int j = MAX_NPAR * HEAD_SIZE; j < 2 * MAX_NPAR * HEAD_SIZE; - j += HEAD_SIZE) { - // lastj is last valid partition - const int lastj_offset = - (j < num_partition_offset) ? j : last_partition_offset; - tmps[idx] = tmp_out_ptr[lastj_offset]; - idx++; - } + for (int j = p * MAX_NPAR * HEAD_SIZE; j < (p + 1) * MAX_NPAR * HEAD_SIZE; + j += HEAD_SIZE) { + // lastj is last valid partition + const int lastj_offset = + (j < num_partition_offset) ? j : last_partition_offset; + tmps[idx] = tmp_out_ptr[lastj_offset]; + idx++; + } #pragma unroll - for (int j = 0; j < MAX_NPAR; j++) { - acc += to_float(tmps[j]) * shared_exp_sums[j + MAX_NPAR]; + for (int j = 0; j < MAX_NPAR; j++) { + acc += to_float(tmps[j]) * shared_exp_sums[j + p * MAX_NPAR]; + } } } const float inv_global_exp_sum = __fdividef(1.0f, shared_global_exp_sum + 1e-6f); acc *= inv_global_exp_sum; - scalar_t* out_ptr = - out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE; - out_ptr[threadIdx.x] = from_float(acc); + OUTT* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE; + if constexpr (std::is_same::value) { + out_ptr[threadIdx.x] = hip_fp8(acc).data; + } else { + out_ptr[threadIdx.x] = from_float(acc); + } } #else // !defined(__HIP__MI300_MI250__) TODO: Add NAVI support template +__global__ +__launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma16_kernel( + const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] + const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, + // head_size/x, block_size, x] + const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, + // head_size, block_size] + const int num_kv_heads, const float scale, + const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] + const int* __restrict__ context_lens, // [num_seqs] + const int max_num_blocks_per_seq, + const float* __restrict__ alibi_slopes, // [num_heads] + const int q_stride, const int kv_block_stride, const int kv_head_stride, + float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] + float* __restrict__ max_logits, // [num_seqs, num_heads, + // max_num_partitions] + scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions, + // head_size] + OUTT* __restrict__ final_out, // [num_seqs, num_heads, head_size] + int max_ctx_blocks, const float* k_scale_ptr, const float* v_scale_ptr) { + UNREACHABLE_CODE +} + +template -__global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( +__global__ +__launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel( const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, // head_size/x, block_size, x] @@ -887,19 +1565,19 @@ __global__ __launch_bounds__(NUM_THREADS) void 
paged_attention_ll4mi_QKV_kernel( float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] float* __restrict__ max_logits, // [num_seqs, num_heads, // max_num_partitions] - scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions, - // head_size] - scalar_t* __restrict__ final_out, // [num_seqs, num_heads, head_size] + scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions, + // head_size] + OUTT* __restrict__ final_out, // [num_seqs, num_heads, head_size] int max_ctx_blocks, const float* k_scale, const float* v_scale) { UNREACHABLE_CODE } // Grid: (num_heads, num_seqs). -template +template __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel( - scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] + OUTT* __restrict__ out, // [num_seqs, num_heads, head_size] const float* __restrict__ exp_sums, // [num_seqs, num_heads, // max_num_partitions] const float* __restrict__ max_logits, // [num_seqs, num_heads, @@ -913,9 +1591,10 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel( #endif // defined(__HIP__MI300_MI250__) TODO: Add NAVI support -#define LAUNCH_CUSTOM_ATTENTION(GQA_RATIO) \ - paged_attention_ll4mi_QKV_kernel \ +#define LAUNCH_CUSTOM_ATTENTION_MFMA16(GQA_RATIO) \ + paged_attention_ll4mi_QKV_mfma16_kernel \ <<>>( \ query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, scale, \ block_tables_ptr, context_lens_ptr, max_num_blocks_per_seq, \ @@ -923,8 +1602,27 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel( exp_sums_ptr, max_logits_ptr, tmp_out_ptr, out_ptr, max_ctx_blocks, \ k_scale_ptr, v_scale_ptr); +#define LAUNCH_CUSTOM_ATTENTION_MFMA4(GQA_RATIO) \ + paged_attention_ll4mi_QKV_mfma4_kernel \ + <<>>( \ + query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, scale, \ + block_tables_ptr, context_lens_ptr, max_num_blocks_per_seq, \ + alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, \ + exp_sums_ptr, max_logits_ptr, tmp_out_ptr, out_ptr, max_ctx_blocks, \ + k_scale_ptr, v_scale_ptr); + +#define LAUNCH_CUSTOM_REDUCTION(NPAR_LOOPS) \ + paged_attention_ll4mi_reduce_kernel \ + <<>>( \ + out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, \ + context_lens_ptr, max_num_partitions); + template + int BLOCK_SIZE, int HEAD_SIZE, typename OUTT, int PARTITION_SIZE_OLD, + bool ALIBI_ENABLED> void paged_attention_custom_launcher( torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits, torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache, @@ -946,7 +1644,6 @@ void paged_attention_custom_launcher( ? 
reinterpret_cast(alibi_slopes.value().data_ptr()) : nullptr; - T* out_ptr = reinterpret_cast(out.data_ptr()); float* exp_sums_ptr = reinterpret_cast(exp_sums.data_ptr()); float* max_logits_ptr = reinterpret_cast(max_logits.data_ptr()); T* tmp_out_ptr = reinterpret_cast(tmp_out.data_ptr()); @@ -955,107 +1652,163 @@ void paged_attention_custom_launcher( KVT* value_cache_ptr = reinterpret_cast(value_cache.data_ptr()); int* block_tables_ptr = block_tables.data_ptr(); int* context_lens_ptr = context_lens.data_ptr(); + const float* k_scale_ptr = reinterpret_cast(k_scale.data_ptr()); const float* v_scale_ptr = reinterpret_cast(v_scale.data_ptr()); + OUTT* out_ptr = reinterpret_cast(out.data_ptr()); const int max_ctx_blocks = DIVIDE_ROUND_UP(max_context_len, BLOCK_SIZE); + + // partition size is fixed at 256 since both mfma4 and mfma16 kernels support + // it mfma4 kernel also supports partition size 512 + constexpr int PARTITION_SIZE = 256; const int max_num_partitions = DIVIDE_ROUND_UP(max_context_len, PARTITION_SIZE); const int gqa_ratio = num_heads / num_kv_heads; assert(num_heads % num_kv_heads == 0); assert(head_size == HEAD_SIZE); - assert(max_num_partitions <= 128); - constexpr int NTHR = PARTITION_SIZE; + constexpr int NTHR = 256; dim3 grid(num_seqs, max_num_partitions, num_kv_heads); dim3 block(NTHR); const at::cuda::OptionalCUDAGuard device_guard(device_of(query)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + // mfma4 kernel is faster than mfma16 for gqa_ratio <= 4 switch (gqa_ratio) { case 1: - LAUNCH_CUSTOM_ATTENTION(1); + LAUNCH_CUSTOM_ATTENTION_MFMA4(1); break; case 2: - LAUNCH_CUSTOM_ATTENTION(2); + LAUNCH_CUSTOM_ATTENTION_MFMA4(2); break; case 3: - LAUNCH_CUSTOM_ATTENTION(3); + LAUNCH_CUSTOM_ATTENTION_MFMA4(3); break; case 4: - LAUNCH_CUSTOM_ATTENTION(4); + LAUNCH_CUSTOM_ATTENTION_MFMA4(4); break; case 5: - LAUNCH_CUSTOM_ATTENTION(5); + LAUNCH_CUSTOM_ATTENTION_MFMA16(5); break; case 6: - LAUNCH_CUSTOM_ATTENTION(6); + LAUNCH_CUSTOM_ATTENTION_MFMA16(6); break; case 7: - LAUNCH_CUSTOM_ATTENTION(7); + LAUNCH_CUSTOM_ATTENTION_MFMA16(7); break; case 8: - LAUNCH_CUSTOM_ATTENTION(8); + LAUNCH_CUSTOM_ATTENTION_MFMA16(8); break; case 9: - LAUNCH_CUSTOM_ATTENTION(9); + LAUNCH_CUSTOM_ATTENTION_MFMA16(9); break; case 10: - LAUNCH_CUSTOM_ATTENTION(10); + LAUNCH_CUSTOM_ATTENTION_MFMA16(10); break; case 11: - LAUNCH_CUSTOM_ATTENTION(11); + LAUNCH_CUSTOM_ATTENTION_MFMA16(11); break; case 12: - LAUNCH_CUSTOM_ATTENTION(12); + LAUNCH_CUSTOM_ATTENTION_MFMA16(12); break; case 13: - LAUNCH_CUSTOM_ATTENTION(13); + LAUNCH_CUSTOM_ATTENTION_MFMA16(13); break; case 14: - LAUNCH_CUSTOM_ATTENTION(14); + LAUNCH_CUSTOM_ATTENTION_MFMA16(14); break; case 15: - LAUNCH_CUSTOM_ATTENTION(15); + LAUNCH_CUSTOM_ATTENTION_MFMA16(15); break; case 16: - LAUNCH_CUSTOM_ATTENTION(16); + LAUNCH_CUSTOM_ATTENTION_MFMA16(16); break; default: TORCH_CHECK(false, "Unsupported gqa ratio: ", gqa_ratio); break; } - // dim3 grid2(num_heads,num_seqs,head_size/HEAD_ELEMS_PER_WG); - // dim3 block2(1024); - // LAUNCH_CUSTOM_ATTENTION2; - - // reduction kernel is only required if max_context_len > partition size, - // otherwise main kernel writes directly to final output - // note there are cases with graphing where max_context_len is the max - // supported by graphing, not the actual max among all the sequences: in that - // case reduction kernel will still run but return immediately - if (max_context_len > PARTITION_SIZE) { - dim3 reduce_grid(num_heads, num_seqs); - dim3 reduce_block(head_size); - 
paged_attention_ll4mi_reduce_kernel - <<>>( - out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, - context_lens_ptr, max_num_partitions); + + dim3 reduce_grid(num_heads, num_seqs); + dim3 reduce_block(head_size); + const int npar_loops = DIVIDE_ROUND_UP(max_num_partitions, WARP_SIZE); + // reduction kernel supports upto 8 NPAR_loops * 64 (warp_size) * 256 + // (partition size) = 128K context length + switch (npar_loops) { + case 1: + LAUNCH_CUSTOM_REDUCTION(1); + break; + case 2: + LAUNCH_CUSTOM_REDUCTION(2); + break; + case 3: + LAUNCH_CUSTOM_REDUCTION(3); + break; + case 4: + LAUNCH_CUSTOM_REDUCTION(4); + break; + case 5: + LAUNCH_CUSTOM_REDUCTION(5); + break; + case 6: + LAUNCH_CUSTOM_REDUCTION(6); + break; + case 7: + LAUNCH_CUSTOM_REDUCTION(7); + break; + case 8: + LAUNCH_CUSTOM_REDUCTION(8); + break; + default: + TORCH_CHECK(false, "Unsupported npar_loops: ", npar_loops); + break; } } -#define CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE) \ - paged_attention_custom_launcher( \ - out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \ - num_kv_heads, scale, block_tables, context_lens, max_context_len, \ +#define CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, OUTT, \ + PSIZE, ALIBI_ENABLED) \ + paged_attention_custom_launcher( \ + out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \ + num_kv_heads, scale, block_tables, context_lens, max_context_len, \ alibi_slopes, k_scale, v_scale); +#define CALL_CUSTOM_LAUNCHER_ALIBI(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, \ + OUTT, PSIZE) \ + if (alibi_slopes) { \ + CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, OUTT, PSIZE, \ + true); \ + } else { \ + CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, OUTT, PSIZE, \ + false); \ + } + +#define CALL_CUSTOM_LAUNCHER_PSIZE(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, \ + OUTT) \ + switch (partition_size) { \ + case 256: \ + CALL_CUSTOM_LAUNCHER_ALIBI(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, OUTT, \ + 256); \ + break; \ + default: \ + TORCH_CHECK(false, "Unsupported partition size: ", partition_size); \ + break; \ + } + +#if defined(__HIPCC__) && defined(__gfx90a__) + #define CALL_CUSTOM_LAUNCHER_OUT(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE) \ + CALL_CUSTOM_LAUNCHER_PSIZE(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, T); +#else + #define CALL_CUSTOM_LAUNCHER_OUT(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE) \ + CALL_CUSTOM_LAUNCHER_PSIZE(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, T); +#endif #define CALL_CUSTOM_LAUNCHER_BLK(T, KVT, KV_DTYPE, HEAD_SIZE) \ switch (block_size) { \ case 16: \ - CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, 16, HEAD_SIZE); \ + CALL_CUSTOM_LAUNCHER_OUT(T, KVT, KV_DTYPE, 16, HEAD_SIZE); \ break; \ case 32: \ - CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, 32, HEAD_SIZE); \ + CALL_CUSTOM_LAUNCHER_OUT(T, KVT, KV_DTYPE, 32, HEAD_SIZE); \ break; \ default: \ TORCH_CHECK(false, "Unsupported block size: ", block_size); \ @@ -1074,7 +1827,6 @@ void paged_attention_custom_launcher( TORCH_CHECK(false, "Unsupported head size: ", head_size); \ break; \ } - void paged_attention( torch::Tensor& out, // [num_seqs, num_heads, head_size] torch::Tensor& exp_sums, // [num_seqs, num_heads, max_num_partitions] @@ -1092,7 +1844,7 @@ void paged_attention( int64_t block_size, int64_t max_context_len, const std::optional& alibi_slopes, const std::string& kv_cache_dtype, torch::Tensor& k_scale, - torch::Tensor& v_scale) { + torch::Tensor& v_scale, int64_t partition_size) { const int head_size = query.size(2); if (kv_cache_dtype == "auto") { if (query.dtype() == 
at::ScalarType::Half) { @@ -1122,4 +1874,4 @@ void paged_attention( #undef WARP_SIZE #undef MAX #undef MIN -#undef DIVIDE_ROUND_UP \ No newline at end of file +#undef DIVIDE_ROUND_UP diff --git a/csrc/rocm/ops.h b/csrc/rocm/ops.h index ba161951772ad..23a0828fb37a0 100644 --- a/csrc/rocm/ops.h +++ b/csrc/rocm/ops.h @@ -11,4 +11,5 @@ void paged_attention(torch::Tensor& out, torch::Tensor& exp_sums, int64_t max_context_len, const std::optional& alibi_slopes, const std::string& kv_cache_dtype, torch::Tensor& k_scale, - torch::Tensor& v_scale); + torch::Tensor& v_scale, + int64_t partition_size); diff --git a/csrc/rocm/torch_bindings.cpp b/csrc/rocm/torch_bindings.cpp index a5d2e2f97a3ed..ad3431f349441 100644 --- a/csrc/rocm/torch_bindings.cpp +++ b/csrc/rocm/torch_bindings.cpp @@ -27,7 +27,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, rocm_ops) { " int max_context_len," " Tensor? alibi_slopes," " str kv_cache_dtype," - " Tensor k_scale, Tensor v_scale) -> ()"); + " Tensor k_scale, Tensor v_scale," + " int partition_size) -> ()"); rocm_ops.impl("paged_attention", torch::kCUDA, &paged_attention); } diff --git a/docs/source/features/quantization/fp8_e4m3_kvcache.md b/docs/source/features/quantization/fp8_e4m3_kvcache.md new file mode 100644 index 0000000000000..e2babefcc3d0d --- /dev/null +++ b/docs/source/features/quantization/fp8_e4m3_kvcache.md @@ -0,0 +1,41 @@ +(fp8-e4m3-kvcache)= + +# FP8 E4M3 KV Cache + +Quantizing the KV cache to FP8 reduces its memory footprint. This increases the number of tokens that can be stored in the cache, +improving throughput. OCP (Open Compute Project www.opencompute.org) specifies two common 8-bit floating point data formats: E5M2 +(5 exponent bits and 2 mantissa bits) and E4M3FN (4 exponent bits and 3 mantissa bits), often shortened as E4M3. One benefit of +the E4M3 format over E5M2 is that floating point numbers are represented in higher precision. However, the small dynamic range of +FP8 E4M3 (±240.0 can be represented) typically necessitates the use of a higher-precision (typically FP32) scaling factor alongside +each quantized tensor. For now, only per-tensor (scalar) scaling factors are supported. Development is ongoing to support scaling +factors of a finer granularity (e.g. per-channel). + +These scaling factors can be specified by passing an optional quantization param JSON to the LLM engine at load time. If +this JSON is not specified, scaling factors default to 1.0. These scaling factors are typically obtained when running an +unquantized model through a quantizer tool (e.g. AMD quantizer or NVIDIA AMMO). + +To install AMMO (AlgorithMic Model Optimization): + +```console +pip install --no-cache-dir --extra-index-url https://pypi.nvidia.com nvidia-ammo +``` + +Studies have shown that FP8 E4M3 quantization typically only minimally degrades inference accuracy. The most recent silicon +offerings e.g. AMD MI300, NVIDIA Hopper or later support native hardware conversion to and from fp32, fp16, bf16, etc. +Thus, LLM inference is greatly accelerated with minimal accuracy loss. + +Here is an example of how to enable this feature: + +```python +# two float8_e4m3fn kv cache scaling factor files are provided under tests/fp8_kv, please refer to +# https://github.com/vllm-project/vllm/blob/main/examples/other/fp8/README.md to generate kv_cache_scales.json of your own. 
+ +from vllm import LLM, SamplingParams +sampling_params = SamplingParams(temperature=1.3, top_p=0.8) +llm = LLM(model="meta-llama/Llama-2-7b-chat-hf", + kv_cache_dtype="fp8", + calculate_kv_scales=True) +prompt = "London is the capital of" +out = llm.generate(prompt, sampling_params)[0].outputs[0].text +print(out) +``` diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 85c1121ed6ff8..ef1eb097956d3 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -114,12 +114,15 @@ def paged_attention_rocm( kv_cache_dtype: str, k_scale: torch.Tensor, v_scale: torch.Tensor, + partition_size: int, ) -> None: torch.ops._rocm_C.paged_attention(out, exp_sum, max_logits, tmp_out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, seq_lens, block_size, max_seq_len, alibi_slopes, - kv_cache_dtype, k_scale, v_scale) + kv_cache_dtype, k_scale, v_scale, + partition_size) + # pos encoding ops @@ -156,6 +159,18 @@ def fused_add_rms_norm(input: torch.Tensor, residual: torch.Tensor, torch.ops._C.fused_add_rms_norm(input, residual, weight, epsilon) +def scaled_rms_norm(out: torch.Tensor, input: torch.Tensor, + weight: torch.Tensor, scale: torch.Tensor, + epsilon: float) -> None: + torch.ops._C.rms_norm_static_fp8_quant(out, input, weight, scale, epsilon) + + +def scaled_fused_add_rms_norm(out: torch.Tensor, input: torch.Tensor, + residual: torch.Tensor, weight: torch.Tensor, + scale: torch.Tensor, epsilon: float) -> None: + torch.ops._C.fused_add_rms_norm_static_fp8_quant(out, input, residual, + weight, scale, epsilon) + def advance_step_flashattn(num_seqs: int, num_queries: int, block_size: int, input_tokens: torch.Tensor, sampled_token_ids: torch.Tensor, diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py index 8027a52b82ffc..9f578b9bb1baf 100644 --- a/vllm/attention/backends/abstract.py +++ b/vllm/attention/backends/abstract.py @@ -232,8 +232,10 @@ class AttentionLayer(Protocol): _k_scale: torch.Tensor _v_scale: torch.Tensor - _k_scale_float: float - _v_scale_float: float + _k_scale_float: torch.Tensor + _v_scale_float: torch.Tensor + _q_scale: torch.Tensor + _prob_scale: torch.Tensor def forward( self, diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index 4a9aa1e217365..176152ab99cb8 100755 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -671,6 +671,7 @@ def forward( kv_cache: torch.Tensor, attn_metadata: FlashAttentionMetadata, output: Optional[torch.Tensor] = None, + attn_type: AttentionType = AttentionType.DECODER, ) -> torch.Tensor: """Forward pass with FlashAttention. diff --git a/vllm/attention/backends/hpu_attn.py b/vllm/attention/backends/hpu_attn.py index 80c132c0a8c05..aa589dc51f958 100644 --- a/vllm/attention/backends/hpu_attn.py +++ b/vllm/attention/backends/hpu_attn.py @@ -160,6 +160,7 @@ def forward( kv_cache: torch.Tensor, attn_metadata: HPUAttentionMetadata, output: Optional[torch.Tensor] = None, + fp8_comp_scales: Optional[Tuple[torch.Tensor, ...]] = None, ) -> torch.Tensor: """Forward pass with xFormers and PagedAttention. 
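The new `scaled_rms_norm` / `scaled_fused_add_rms_norm` wrappers bind RMSNorm kernels that fuse a static per-tensor FP8 quantization of the output. A rough pure-PyTorch sketch of the intended semantics (an out-of-place approximation under the assumption of a per-tensor scale; the real ops write `out` and `residual` in place and may differ in rounding/clamping details):

```python
import torch

FP8 = torch.float8_e4m3fnuz  # ROCm FP8 flavor used elsewhere in this PR

def ref_scaled_fused_add_rms_norm(x, residual, weight, scale, eps):
    # Fused residual add; the real op updates the residual buffer in place.
    residual = x + residual
    # RMSNorm over the hidden dimension, applied to the updated residual.
    var = residual.float().pow(2).mean(dim=-1, keepdim=True)
    normed = residual.float() * torch.rsqrt(var + eps) * weight.float()
    # Static per-tensor FP8 quantization: divide by the scale, clamp, cast.
    finfo = torch.finfo(FP8)
    out = (normed / scale).clamp(finfo.min, finfo.max).to(FP8)
    return out, residual
```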
diff --git a/vllm/attention/backends/ipex_attn.py b/vllm/attention/backends/ipex_attn.py index 57916a3c6a34c..0562c2817dc5b 100644 --- a/vllm/attention/backends/ipex_attn.py +++ b/vllm/attention/backends/ipex_attn.py @@ -179,6 +179,7 @@ def forward( kv_cache: torch.Tensor, attn_metadata: IpexAttnMetadata, # type: ignore output: Optional[torch.Tensor] = None, + attn_type: AttentionType = AttentionType.DECODER, ) -> torch.Tensor: """Forward pass with IPEX varlen_attention and PagedAttention. diff --git a/vllm/attention/backends/pallas.py b/vllm/attention/backends/pallas.py index 209a623ba441c..30d1f36d959d6 100644 --- a/vllm/attention/backends/pallas.py +++ b/vllm/attention/backends/pallas.py @@ -156,6 +156,7 @@ def forward( kv_cache: Tuple[torch.Tensor, torch.Tensor], attn_metadata: PallasMetadata, output: Optional[torch.Tensor] = None, + attn_type: AttentionType = AttentionType.DECODER, ) -> torch.Tensor: """Forward pass with Pallas attention. diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index ca6fa9ca61b30..69cab2099c52b 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -1,3 +1,4 @@ +# SPDX-License-Identifier: Apache-2.0 """Attention layer ROCm GPUs.""" from dataclasses import dataclass from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type @@ -21,7 +22,7 @@ logger = init_logger(__name__) -_PARTITION_SIZE_ROCM = 512 +_PARTITION_SIZE_ROCM = 256 _GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName _ON_NAVI = "gfx1" in _GPU_ARCH _ON_MI250_MI300 = any(arch in _GPU_ARCH @@ -90,6 +91,17 @@ class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata): seq_lens: Optional[List[int]] # seq_lens stored as a tensor. seq_lens_tensor: Optional[torch.Tensor] + # Maximum sequence length among prefill batch. 0 if there are decoding + # requests only. + max_prefill_seq_len: int + # Maximum sequence length among decode batch. 0 if there are prefill + # requests only. + max_decode_seq_len: int + + # Whether or not if cuda graph is enabled. + # Cuda-graph is currently enabled for decoding only. + # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention. + use_cuda_graph: bool # NOTE(sang): Definition of context_len, query_len, and seq_len. # |---------- N-1 iteration --------| @@ -100,30 +112,18 @@ class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata): # |-- query_len ---| # Maximum query length in the batch. None for decoding. - max_query_len: Optional[int] - # Maximum sequence length among prefill batch. 0 if there are decoding - # requests only. - max_prefill_seq_len: int - # Maximum sequence length among decode batch. 0 if there are prefill - # requests only. - max_decode_seq_len: int + max_query_len: Optional[int] = None # (batch_size + 1,). The cumulative subquery lengths of the sequences in # the batch, used to index into subquery. E.g., if the subquery length # is [4, 6], it is [0, 4, 10]. - query_start_loc: Optional[torch.Tensor] + query_start_loc: Optional[torch.Tensor] = None # (batch_size + 1,). The cumulative sequence lengths of the sequences in # the batch, used to index into sequence. E.g., if the sequence length is # [4, 6], it is [0, 4, 10]. - seq_start_loc: Optional[torch.Tensor] - - # Whether or not if cuda graph is enabled. - # Cuda-graph is currently enabled for decoding only. - # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention. 
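With `_PARTITION_SIZE_ROCM` dropping from 512 to 256, the custom decode path splits each context into 256-token partitions and sizes its per-partition buffers accordingly. A small sketch of the shapes involved, matching the `[num_seqs, num_heads, max_num_partitions]` layout in the kernel signatures (function name is illustrative):

```python
import math

PARTITION_SIZE = 256  # matches _PARTITION_SIZE_ROCM after this change

def decode_workspace_shapes(num_seqs, num_heads, head_size, max_seq_len):
    max_num_partitions = math.ceil(max_seq_len / PARTITION_SIZE)
    # Per-partition softmax statistics and partial outputs, reduced afterwards.
    exp_sums = (num_seqs, num_heads, max_num_partitions)
    max_logits = (num_seqs, num_heads, max_num_partitions)
    tmp_out = (num_seqs, num_heads, max_num_partitions, head_size)
    return exp_sums, max_logits, tmp_out

# e.g. a 32K-token context now spans 128 partitions instead of 64.
print(decode_workspace_shapes(num_seqs=8, num_heads=32, head_size=128,
                              max_seq_len=32 * 1024))
```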
- use_cuda_graph: bool - + seq_start_loc: Optional[torch.Tensor] = None # (batch_size,) A tensor of context lengths (tokens that are computed # so far). - context_lens_tensor: Optional[torch.Tensor] + context_lens_tensor: Optional[torch.Tensor] = None # Max number of query tokens among request in the batch. max_decode_query_len: Optional[int] = None @@ -131,6 +131,23 @@ class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata): _cached_prefill_metadata: Optional["ROCmFlashAttentionMetadata"] = None _cached_decode_metadata: Optional["ROCmFlashAttentionMetadata"] = None + # Begin encoder attn & enc/dec cross-attn fields... + + # Encoder sequence lengths representation + encoder_seq_lens: Optional[List[int]] = None + encoder_seq_lens_tensor: Optional[torch.Tensor] = None + + # Maximum sequence length among encoder sequences + max_encoder_seq_len: Optional[int] = None + + # Number of tokens input to encoder + num_encoder_tokens: Optional[int] = None + + # Cross-attention memory-mapping data structures: slot mapping + # and block tables + cross_slot_mapping: Optional[torch.Tensor] = None + cross_block_tables: Optional[torch.Tensor] = None + @property def prefill_metadata(self) -> Optional["ROCmFlashAttentionMetadata"]: if self.num_prefills == 0: @@ -141,10 +158,7 @@ def prefill_metadata(self) -> Optional["ROCmFlashAttentionMetadata"]: assert self.seq_lens is not None assert self.seq_lens_tensor is not None - assert self.query_start_loc is not None - assert self.context_lens_tensor is not None assert self.block_tables is not None - assert self.seq_start_loc is not None self._cached_prefill_metadata = ROCmFlashAttentionMetadata( num_prefills=self.num_prefills, @@ -159,12 +173,20 @@ def prefill_metadata(self) -> Optional["ROCmFlashAttentionMetadata"]: max_query_len=self.max_query_len, max_prefill_seq_len=self.max_prefill_seq_len, max_decode_seq_len=0, - query_start_loc=self.query_start_loc[:self.num_prefills + 1], - seq_start_loc=self.seq_start_loc[:self.num_prefills + 1], - context_lens_tensor=self.context_lens_tensor[:self.num_prefills], + query_start_loc=None if self.query_start_loc is None else + self.query_start_loc[:self.num_prefills + 1], + seq_start_loc=None if self.seq_start_loc is None else + self.seq_start_loc[:self.num_prefills + 1], + context_lens_tensor=None if self.context_lens_tensor is None else + self.context_lens_tensor[:self.num_prefills], block_tables=self.block_tables[:self.num_prefills], use_cuda_graph=False, - ) + # Begin encoder & cross attn fields below... + encoder_seq_lens=self.encoder_seq_lens, + encoder_seq_lens_tensor=self.encoder_seq_lens_tensor, + max_encoder_seq_len=self.max_encoder_seq_len, + cross_slot_mapping=self.cross_slot_mapping, + cross_block_tables=self.cross_block_tables) return self._cached_prefill_metadata @property @@ -194,7 +216,12 @@ def decode_metadata(self) -> Optional["ROCmFlashAttentionMetadata"]: context_lens_tensor=None, block_tables=self.block_tables[self.num_prefills:], use_cuda_graph=self.use_cuda_graph, - ) + # Begin encoder & cross attn fields below... + encoder_seq_lens=self.encoder_seq_lens, + encoder_seq_lens_tensor=self.encoder_seq_lens_tensor, + max_encoder_seq_len=self.max_encoder_seq_len, + cross_slot_mapping=self.cross_slot_mapping, + cross_block_tables=self.cross_block_tables) # Batch may be composed of prefill|decodes, adjust query start indices # to refer to the start of decodes when the two are split apart. # E.g. in tokens:[3 prefills|6 decodes], query_start_loc=[3,9] => [0,6]. 
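The comment above on re-basing `query_start_loc` when decode metadata is split out of a mixed batch can be illustrated in a few lines (values taken from the comment's own example; this is a sketch, not the exact implementation):

```python
import torch

# Mixed batch: one prefill of 3 tokens followed by decodes totalling 6 tokens.
query_start_loc = torch.tensor([0, 3, 9])
num_prefills = 1

# Keep only the decode entries and re-base them so the first decode starts at 0.
decode_start_loc = query_start_loc[num_prefills:].clone()
decode_start_loc -= decode_start_loc[0]
print(decode_start_loc.tolist())  # [0, 6], matching the example in the comment
```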
@@ -304,6 +331,97 @@ def _make_alibi_bias(alibi_slopes: torch.Tensor, return attn_biases +def _get_seq_len_block_table_args( + attn_metadata: ROCmFlashAttentionMetadata, + attn_type: str, +) -> tuple: + ''' + The particular choice of sequence-length + attributes which should be extracted from attn_metadata is dependent + on the type of attention operation. + + Decoder attn -> select entirely decoder self-attention-related fields + Encoder/decoder cross-attn -> select encoder sequence lengths + Encoder attn -> select encoder sequence lengths fields + + Arguments: + + * attn_metadata: Attention metadata structure associated with attention op + * attn_type: encoder attention, decoder self-attention, + encoder/decoder cross-attention + + Returns: + + * Appropriate sequence-lengths tensors for query and key + * Appropriate max sequence-length scalar + ''' + + partial_prefix_sum = 0 + if attn_type == AttentionType.ENCODER: + assert attn_metadata.encoder_seq_lens is not None + assert attn_metadata.encoder_seq_lens_tensor is not None + query_seq_start_loc = torch.tensor( + [0] + [ + partial_prefix_sum := partial_prefix_sum + i + for i in attn_metadata.encoder_seq_lens + ], + device=attn_metadata.encoder_seq_lens_tensor.device, + dtype=attn_metadata.encoder_seq_lens_tensor.dtype) + causal_mask = False + + # No block tables associated with encoder attention + return (query_seq_start_loc, attn_metadata.max_encoder_seq_len, + query_seq_start_loc, attn_metadata.max_encoder_seq_len, + attn_metadata.encoder_seq_lens, causal_mask) + elif attn_type == AttentionType.DECODER: + # Decoder self-attention + # Choose max_seq_len based on whether we are in prompt_run + assert attn_metadata.seq_lens is not None + assert attn_metadata.seq_lens_tensor is not None + query_seq_start_loc = torch.tensor( + [0] + [ + partial_prefix_sum := partial_prefix_sum + i + for i in attn_metadata.seq_lens + ], + device=attn_metadata.seq_lens_tensor.device, + dtype=attn_metadata.seq_lens_tensor.dtype) + max_seq_len = attn_metadata.max_prefill_seq_len + causal_mask = True + + return (query_seq_start_loc, max_seq_len, query_seq_start_loc, + max_seq_len, attn_metadata.seq_lens, causal_mask) + elif attn_type == AttentionType.ENCODER_DECODER: + assert attn_metadata.seq_lens is not None + assert attn_metadata.encoder_seq_lens_tensor is not None + query_start_loc = torch.tensor( + [0] + [ + partial_prefix_sum := partial_prefix_sum + i + for i in attn_metadata.seq_lens + ], + device=attn_metadata.encoder_seq_lens_tensor.device, + dtype=attn_metadata.encoder_seq_lens_tensor.dtype) + + partial_prefix_sum = 0 + assert attn_metadata.encoder_seq_lens is not None + assert attn_metadata.seq_lens_tensor is not None + key_seq_start_loc = torch.tensor( + [0] + [ + partial_prefix_sum := partial_prefix_sum + i + for i in attn_metadata.encoder_seq_lens + ], + device=attn_metadata.seq_lens_tensor.device, + dtype=attn_metadata.seq_lens_tensor.dtype) + causal_mask = False + + # Enc/dec cross-attention KVs match encoder sequence length; + # cross-attention utilizes special "cross" block tables + return (query_start_loc, attn_metadata.max_prefill_seq_len, + key_seq_start_loc, attn_metadata.max_encoder_seq_len, + attn_metadata.seq_lens, causal_mask) + else: + raise AttributeError(f"Invalid attention type {str(attn_type)}") + + class ROCmFlashAttentionImpl(AttentionImpl): """ If the input tensors contain prompt tokens, the layout is as follows: @@ -346,10 +464,13 @@ def __init__( if blocksparse_params is not None: raise ValueError( "ROCmFlashAttention does not 
support blocksparse attention.") - if logits_soft_cap is not None: - raise ValueError( - "ROCmFlashAttention does not support attention logits soft " - "capping.") + + if logits_soft_cap is None: + # In flash-attn, setting logits_soft_cap as 0 means no soft cap. + self.logits_soft_cap = 0.0 + else: + self.logits_soft_cap = logits_soft_cap + self.attn_type = attn_type self.num_heads = num_heads self.head_size = head_size self.scale = float(scale) @@ -370,10 +491,18 @@ def __init__( f"Head size {head_size} is not supported by PagedAttention. " f"Supported head sizes are: {supported_head_sizes}.") - self.use_naive_attn = False + self.use_naive_attn = envs.VLLM_USE_SDPA_ATTENTION # Default False # NOTE: Allow for switching between Triton and CK. Defaulting to triton. self.use_triton_flash_attn = envs.VLLM_USE_TRITON_FLASH_ATTN if self.use_triton_flash_attn: + if logits_soft_cap is not None: + raise ValueError( + "ROCm Triton FlashAttention does not support attention" + "logits soft capping." + " please try using the ROCm CK " + "FA backend instead by setting the env var " + "`VLLM_USE_TRITON_FLASH_ATTN=0`") + from vllm.attention.ops.triton_flash_attention import ( # noqa: F401 triton_attention) self.attn_func = triton_attention @@ -398,14 +527,13 @@ def __init__( self.use_naive_attn = True if self.use_naive_attn: - self.attn_func = _sdpa_attention - logger.debug("Using naive attention in ROCmBackend") + if logits_soft_cap is not None: + raise ValueError( + "ROCm Naive FlashAttention does not support" + "attention logits soft capping.") - if attn_type != AttentionType.DECODER: - raise NotImplementedError("Encoder self-attention and " - "encoder/decoder cross-attention " - "are not implemented for " - "ROCmFlashAttentionImpl") + self.attn_func = _sdpa_attention + logger.debug("Using naive (SDPA) attention in ROCmBackend") def repeat_kv(self, x: torch.Tensor, n_rep: int) -> torch.Tensor: """torch.repeat_interleave(x, dim=1, repeats=n_rep)""" @@ -427,6 +555,37 @@ def forward( ) -> torch.Tensor: """Forward pass with FlashAttention and PagedAttention. + For decoder-only models: query, key and value must be non-None. + + For encoder/decoder models: + * ROCmFlashAttentionImpl.forward() may be invoked for both self- and + cross-attention layers. + * For self-attention: query, key and value must be non-None. + * For cross-attention: + * Query must be non-None + * During prefill, key and value must be non-None; key and value + get cached for use during decode. 
+ * During decode, key and value may be None, since: + (1) key and value tensors were cached during prefill, and + (2) cross-attention key and value tensors do not grow during + decode + + A note on how the attn_type (attention type enum) argument impacts + attention forward() behavior: + + * DECODER: normal decoder-only behavior; + use decoder self-attention block table + * ENCODER: no KV caching; pass encoder sequence + attributes (encoder_seq_lens/encoder_seq_lens_tensor/ + max_encoder_seq_len) to kernel, in lieu of decoder + sequence attributes (seq_lens/seq_lens_tensor/max_seq_len) + * ENCODER_DECODER: cross-attention behavior; + use cross-attention block table for caching KVs derived + from encoder hidden states; since KV sequence lengths + will match encoder sequence lengths, pass encoder sequence + attributes to kernel (encoder_seq_lens/encoder_seq_lens_tensor/ + max_encoder_seq_len) + Args: query: shape = [num_tokens, num_heads * head_size] key: shape = [num_tokens, num_kv_heads * head_size] @@ -435,54 +594,81 @@ def forward( NOTE: kv_cache will be an empty tensor with shape [0] for profiling run. attn_metadata: Metadata for attention. + attn_type: Select attention type, between encoder attention, + decoder self-attention, or encoder/decoder cross- + attention. Defaults to decoder self-attention, + which is the vLLM default generally Returns: shape = [num_tokens, num_heads * head_size] """ - # Reminder: Please update docs/source/features/compatibility_matrix.md - # If the feature combo become valid - num_tokens, hidden_size = query.shape - # Reshape the query, key, and value tensors. + fp8_out_scale = None query = query.view(-1, self.num_heads, self.head_size) - key = key.view(-1, self.num_kv_heads, self.head_size) - value = value.view(-1, self.num_kv_heads, self.head_size) + if key is not None: + assert value is not None + key = key.view(-1, self.num_kv_heads, self.head_size) + value = value.view(-1, self.num_kv_heads, self.head_size) + else: + assert value is None - if kv_cache.numel() > 0: + if self.attn_type != AttentionType.ENCODER and kv_cache.numel() > 0: key_cache, value_cache = PagedAttention.split_kv_cache( kv_cache, self.num_kv_heads, self.head_size) - # Reshape the input keys and values and store them in the cache. - # If kv_cache is not provided, the new key and value tensors are - # not cached. This happens during the initial memory profiling run. - PagedAttention.write_to_paged_cache( - key, - value, - key_cache, - value_cache, - attn_metadata.slot_mapping, - self.kv_cache_dtype, - layer._k_scale, - layer._v_scale, - ) - - num_prefill_tokens = attn_metadata.num_prefill_tokens - num_decode_tokens = attn_metadata.num_decode_tokens - assert key.shape[0] == num_prefill_tokens + num_decode_tokens - assert value.shape[0] == num_prefill_tokens + num_decode_tokens + if key is not None and value is not None: + # Reshape the input keys and values and store them in the + # cache. If kv_cache is not provided, the new key and value + # tensors are not cached. This happens during the initial + # memory profiling run. 
+ PagedAttention.write_to_paged_cache( + key, + value, + key_cache, + value_cache, + attn_metadata.slot_mapping + if self.attn_type != AttentionType.ENCODER_DECODER else + attn_metadata.cross_slot_mapping, + self.kv_cache_dtype, + layer._k_scale, + layer._v_scale, + ) + + if self.attn_type != AttentionType.ENCODER: + num_prefill_tokens = attn_metadata.num_prefill_tokens + else: + assert attn_metadata.num_encoder_tokens is not None + num_prefill_tokens = attn_metadata.num_encoder_tokens output = torch.empty_like(query) # Query for decode. KV is not needed because it is already cached. decode_query = query[num_prefill_tokens:] # QKV for prefill. query = query[:num_prefill_tokens] - key = key[:num_prefill_tokens] - value = value[:num_prefill_tokens] - assert query.shape[0] == num_prefill_tokens - assert decode_query.shape[0] == num_decode_tokens + if key is not None and value is not None \ + and self.attn_type != AttentionType.ENCODER_DECODER: + key = key[:num_prefill_tokens] + value = value[:num_prefill_tokens] if prefill_meta := attn_metadata.prefill_metadata: # Prompt run. - assert prefill_meta.seq_lens is not None + # normal attention and DECODER + if self.attn_type == AttentionType.DECODER and ( + kv_cache.numel() == 0 or prefill_meta.block_tables is None + or prefill_meta.block_tables.numel() == 0): + (query_seq_start_loc, query_max_seq_len, key_seq_start_loc, + key_max_seq_len, seq_lens, + causal_mask) = (prefill_meta.seq_start_loc, + prefill_meta.max_prefill_seq_len, + prefill_meta.seq_start_loc, + prefill_meta.max_prefill_seq_len, + attn_metadata.seq_lens, True) + # prefix-enabled attention and ENCODER/ENCODER_DECODER + else: + (query_seq_start_loc, query_max_seq_len, key_seq_start_loc, + key_max_seq_len, seq_lens, + causal_mask) = _get_seq_len_block_table_args( + prefill_meta, self.attn_type) + # Prompt run. 
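The prefill branch above either reuses `seq_start_loc` from metadata or rebuilds cumulative start offsets from per-sequence lengths via the running prefix sum in `_get_seq_len_block_table_args`. A compact torch equivalent of that construction (helper name is illustrative):

```python
import torch

def cu_seqlens(seq_lens, dtype=torch.int32):
    """Per-sequence lengths -> cumulative start offsets, e.g. [4, 6] -> [0, 4, 10]."""
    lens = torch.tensor(seq_lens, dtype=dtype)
    zero = torch.zeros(1, dtype=dtype)
    return torch.cat([zero, torch.cumsum(lens, dim=0, dtype=dtype)])

print(cu_seqlens([4, 6]).tolist())  # [0, 4, 10]
```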
if kv_cache.numel() == 0 or prefill_meta.block_tables.numel() == 0: # triton attention # When block_tables are not filled, it means q and k are the @@ -493,21 +679,29 @@ def forward( attn_masks = _make_alibi_bias( self.alibi_slopes, query.dtype, - attn_metadata.seq_lens, + seq_lens, make_attn_mask=False) # type: ignore + full_scales = ( + layer._q_scale.item(), layer._k_scale.item(), + layer._v_scale.item(), layer._prob_scale.item(), + fp8_out_scale.item()) if ( + fp8_out_scale and layer._q_scale + and layer._prob_scale + and envs.VLLM_USE_ROCM_FP8_FLASH_ATTN) else None out, _ = self.attn_func( query, key, value, None, - prefill_meta.seq_start_loc, - prefill_meta.seq_start_loc, - prefill_meta.max_prefill_seq_len, - prefill_meta.max_prefill_seq_len, - True, + query_seq_start_loc, + key_seq_start_loc, + query_max_seq_len, + key_max_seq_len, + causal_mask, self.scale, attn_masks[0][None] if attn_masks is not None else None, + full_scales, ) elif self.use_naive_attn: if self.num_kv_heads != self.num_heads: @@ -528,11 +722,12 @@ def forward( query, key, value, - prefill_meta.seq_lens, - num_tokens, + query_seq_start_loc, + num_prefill_tokens, self.num_heads, self.head_size, self.scale, + causal_mask, attn_masks, ) else: @@ -540,19 +735,23 @@ def forward( q=query, k=key, v=value, - cu_seqlens_q=prefill_meta.seq_start_loc, - cu_seqlens_k=prefill_meta.seq_start_loc, + cu_seqlens_q=query_seq_start_loc, + cu_seqlens_k=key_seq_start_loc, max_seqlen_q=prefill_meta.max_prefill_seq_len, - max_seqlen_k=prefill_meta.max_prefill_seq_len, + max_seqlen_k=key_max_seq_len, softmax_scale=self.scale, causal=True, window_size=self.sliding_window, alibi_slopes=self.alibi_slopes, + softcap=self.logits_soft_cap, ) # common code for prefill assert output[:num_prefill_tokens].shape == out.shape - output[:num_prefill_tokens] = out + if output.shape[0] > num_prefill_tokens: + output[:num_prefill_tokens] = out + else: + output = out else: # prefix-enabled attention output[:num_prefill_tokens] = PagedAttention.forward_prefix( @@ -583,7 +782,10 @@ def forward( decode_query.dtype, head_size, block_size, gqa_ratio, decode_meta.max_decode_seq_len) if use_custom: - max_seq_len = decode_meta.max_decode_seq_len + max_seq_len = (decode_meta.max_decode_seq_len if self.attn_type + != AttentionType.ENCODER_DECODER else + decode_meta.max_encoder_seq_len) + assert max_seq_len is not None max_num_partitions = ( (max_seq_len + _PARTITION_SIZE_ROCM - 1) // _PARTITION_SIZE_ROCM) @@ -599,8 +801,12 @@ def forward( device=output.device, ) max_logits = torch.empty_like(exp_sums) + if num_prefill_tokens > 0: + out = output[num_prefill_tokens:] + else: + out = output ops.paged_attention_rocm( - output[num_prefill_tokens:], + out, exp_sums, max_logits, tmp_output, @@ -609,23 +815,34 @@ def forward( value_cache, self.num_kv_heads, self.scale, - decode_meta.block_tables, - decode_meta.seq_lens_tensor, + decode_meta.block_tables + if self.attn_type != AttentionType.ENCODER_DECODER else + decode_meta.cross_block_tables, + decode_meta.seq_lens_tensor + if self.attn_type != AttentionType.ENCODER_DECODER else + decode_meta.encoder_seq_lens_tensor, block_size, max_seq_len, self.alibi_slopes, self.kv_cache_dtype, layer._k_scale, layer._v_scale, + _PARTITION_SIZE_ROCM, ) else: output[num_prefill_tokens:] = PagedAttention.forward_decode( decode_query, key_cache, value_cache, - decode_meta.block_tables, - decode_meta.seq_lens_tensor, - decode_meta.max_decode_seq_len, + decode_meta.block_tables + if self.attn_type != AttentionType.ENCODER_DECODER else + 
decode_meta.cross_block_tables, + decode_meta.seq_lens_tensor + if self.attn_type != AttentionType.ENCODER_DECODER else + decode_meta.encoder_seq_lens_tensor, + decode_meta.max_decode_seq_len + if self.attn_type != AttentionType.ENCODER_DECODER else + decode_meta.max_encoder_seq_len, self.kv_cache_dtype, self.num_kv_heads, self.scale, @@ -635,7 +852,7 @@ def forward( ) # Reshape the output tensor. - return output.view(num_tokens, hidden_size) + return output.view(-1, self.num_heads * self.head_size) def _sdpa_attention( @@ -647,6 +864,7 @@ def _sdpa_attention( num_heads: int, head_size: int, scale: float, + is_causal: bool, attn_masks: Optional[List[torch.Tensor]] = None, ) -> torch.Tensor: start = 0 @@ -664,7 +882,7 @@ def _sdpa_attention( key[:, start:end, :], value[:, start:end, :], dropout_p=0.0, - is_causal=attn_masks is None, + is_causal=is_causal, attn_mask=attn_masks[i] if attn_masks else None, scale=scale).movedim(query.dim() - 2, 0) output[start:end, :, :] = sub_out @@ -681,4 +899,5 @@ def _use_rocm_custom_paged_attention(qtype: torch.dtype, head_size: int, and (qtype == torch.half or qtype == torch.bfloat16) and (head_size == 64 or head_size == 128) and (block_size == 16 or block_size == 32) - and (gqa_ratio >= 1 and gqa_ratio <= 16) and max_seq_len <= 32768) + and (gqa_ratio >= 1 and gqa_ratio <= 16) + and max_seq_len <= 128 * 1024) diff --git a/vllm/attention/backends/torch_sdpa.py b/vllm/attention/backends/torch_sdpa.py index c3b2398b4e632..1d773d7b99d7f 100644 --- a/vllm/attention/backends/torch_sdpa.py +++ b/vllm/attention/backends/torch_sdpa.py @@ -441,6 +441,7 @@ def forward( kv_cache: torch.Tensor, attn_metadata: TorchSDPAMetadata, # type: ignore output: Optional[torch.Tensor] = None, + attn_type: AttentionType = AttentionType.DECODER, ) -> torch.Tensor: """Forward pass with torch SDPA and PagedAttention. diff --git a/vllm/attention/backends/utils.py b/vllm/attention/backends/utils.py index 84fe89b7df360..bee50f38df4ce 100644 --- a/vllm/attention/backends/utils.py +++ b/vllm/attention/backends/utils.py @@ -11,6 +11,7 @@ AttentionState) from vllm.attention.backends.abstract import AttentionType from vllm.multimodal import MultiModalPlaceholderMap +from vllm.platforms import current_platform from vllm.utils import async_tensor_h2d, make_tensor_with_pad if TYPE_CHECKING: @@ -229,9 +230,18 @@ def build(self, seq_lens: List[int], query_lens: List[int], # The shape of graph_block_tables is # [max batch size, max context len // block size]. input_block_tables = self.runner.graph_block_tables[:batch_size] + max_blocks = input_block_tables.shape[1] for i, block_table in enumerate(self.block_tables): if block_table: - input_block_tables[i, :len(block_table)] = block_table + num_blocks = len(block_table) + if num_blocks <= max_blocks: + input_block_tables[i, :num_blocks] = block_table + else: + # It may be possible to have more blocks allocated due + # to lookahead slots of multi-step, however, they are + # not used anyway, so can be safely ignored. + input_block_tables[ + i, :max_blocks] = block_table[:max_blocks] block_tables = torch.from_numpy(input_block_tables).to( device, non_blocking=True) else: @@ -332,15 +342,22 @@ def graph_capture_get_metadata_for_batch( use_cuda_graph=True, ) if is_encoder_decoder_model: - # The encoder decoder model works only with XFormers and - # Flash Attention backend. Assert the same. 
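The relaxed `max_seq_len <= 128 * 1024` gate in `_use_rocm_custom_paged_attention` lines up with the capacity of the new reduction kernel: at most 8 `NPAR_LOOPS` iterations over a 64-lane wavefront, each lane covering one 256-token partition. A quick sanity check of that arithmetic:

```python
WARP_SIZE = 64          # wavefront size on CDNA GPUs
PARTITION_SIZE = 256    # fixed partition size used by the custom kernels
MAX_NPAR_LOOPS = 8      # largest NPAR_LOOPS the reduction kernel handles

max_context = MAX_NPAR_LOOPS * WARP_SIZE * PARTITION_SIZE
assert max_context == 128 * 1024  # 131072 tokens, i.e. the 128K gate above
```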
- assert self.runner.attn_backend.get_name() in\ - ["XFORMERS", "FLASH_ATTN"], \ - f"Expected attn_backend name to be either 'XFORMERS' or " \ - f"'FLASH_ATTN', but "\ - f"got '{self.runner.attn_backend.get_name()}'" - self._update_captured_metadata_for_enc_dec_model( - batch_size=batch_size, attn_metadata=attn_metadata) + # The encoder decoder model works only with XFormers backend. + # Assert the same. + if current_platform.is_rocm(): + assert (self.runner.attn_backend.get_name() == "ROCM_FLASH"), ( + f"Expected attn_backend name to be 'ROCM_FLASH', but " + f" got '{self.runner.attn_backend.get_name()}'") + self._update_captured_metadata_for_enc_dec_model( + batch_size=batch_size, attn_metadata=attn_metadata) + else: + assert self.runner.attn_backend.get_name() in\ + ["XFORMERS", "FLASH_ATTN"], \ + f"Expected attn_backend name to be either 'XFORMERS' or " \ + f"'FLASH_ATTN', but "\ + f"got '{self.runner.attn_backend.get_name()}'" + self._update_captured_metadata_for_enc_dec_model( + batch_size=batch_size, attn_metadata=attn_metadata) return attn_metadata @@ -356,13 +373,20 @@ def get_graph_input_buffers( if is_encoder_decoder_model: # The encoder decoder model works only with XFormers and # Flash Attention backend. Assert the same. - assert self.runner.attn_backend.get_name() in\ + if current_platform.is_rocm(): + assert (self.runner.attn_backend.get_name() == "ROCM_FLASH"), ( + f"Expected attn_backend name to be 'ROCM_FLASH', but " + f" got '{self.runner.attn_backend.get_name()}'") + self._add_additonal_input_buffers_for_enc_dec_model( + attn_metadata=attn_metadata, input_buffers=input_buffers) + else: + assert self.runner.attn_backend.get_name() in\ ["XFORMERS", "FLASH_ATTN"], \ - f"Expected attn_backend name to be either 'XFORMERS' or "\ + f"Expected attn_backend name to be either 'XFORMERS' or "\ f"'FLASH_ATTN', but "\ - f"got '{self.runner.attn_backend.get_name()}'" - self._add_additonal_input_buffers_for_enc_dec_model( - attn_metadata=attn_metadata, input_buffers=input_buffers) + f"got '{self.runner.attn_backend.get_name()}'" + self._add_additonal_input_buffers_for_enc_dec_model( + attn_metadata=attn_metadata, input_buffers=input_buffers) return input_buffers def prepare_graph_input_buffers( @@ -377,13 +401,21 @@ def prepare_graph_input_buffers( if is_encoder_decoder_model: # The encoder decoder model works only with XFormers and # Flash Attention backend. Assert the same. 
- assert self.runner.attn_backend.get_name() in\ + + if current_platform.is_rocm(): + assert (self.runner.attn_backend.get_name() == "ROCM_FLASH"), ( + f"Expected attn_backend name to be 'ROCM_FLASH', but " + f" got '{self.runner.attn_backend.get_name()}'") + self._prepare_input_buffers_for_enc_dec_model( + attn_metadata, input_buffers) + else: + assert self.runner.attn_backend.get_name() in\ ["XFORMERS", "FLASH_ATTN"], \ - f"Expected attn_backend name to be either 'XFORMERS' or "\ + f"Expected attn_backend name to be either 'XFORMERS' or "\ f"'FLASH_ATTN', but "\ - f"got '{self.runner.attn_backend.get_name()}'" - self._prepare_input_buffers_for_enc_dec_model( - attn_metadata, input_buffers) + f"got '{self.runner.attn_backend.get_name()}'" + self._prepare_input_buffers_for_enc_dec_model( + attn_metadata, input_buffers) def begin_forward(self, model_input) -> None: return diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index 49f47f9c8ded3..ede11d44a48fa 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -425,6 +425,7 @@ def forward( kv_cache: torch.Tensor, attn_metadata: "XFormersMetadata", output: Optional[torch.Tensor] = None, + attn_type: AttentionType = AttentionType.DECODER, ) -> torch.Tensor: """Forward pass with xFormers and PagedAttention. diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index 962c45a65ae23..e80fad77f29cc 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -76,6 +76,8 @@ def __init__( self.calculate_kv_scales = calculate_kv_scales self._k_scale = torch.tensor(1.0, dtype=torch.float32) self._v_scale = torch.tensor(1.0, dtype=torch.float32) + self._q_scale = torch.tensor(1.0, dtype=torch.float32) + self._prob_scale = torch.tensor(1.0, dtype=torch.float32) # We also keep the float32 versions of k/v_scale for attention # backends that don't support tensors (Flashinfer) @@ -115,11 +117,11 @@ def __init__( self.backend = backend_name_to_enum(attn_backend.get_name()) self.dtype = dtype - # For cuda-alike (CUDA and ROCM) and cpu platforms, we control how + # For cuda and cpu platforms, we control how # torch.compile works by registering the attention as one giant # opaque custom op. For other platforms, we directly call them # and let torch.compile handle them. 
- self.use_direct_call = not current_platform.is_cuda_alike( + self.use_direct_call = not current_platform.is_cuda( ) and not current_platform.is_cpu() self.use_output = attn_backend.accept_output_buffer @@ -137,6 +139,7 @@ def __init__( ).parallel_config.pipeline_parallel_size) ] + self.q_range = torch.tensor(envs.Q_SCALE_CONSTANT, dtype=torch.float32) self.k_range = torch.tensor(envs.K_SCALE_CONSTANT, dtype=torch.float32) self.v_range = torch.tensor(envs.V_SCALE_CONSTANT, dtype=torch.float32) @@ -150,7 +153,7 @@ def forward( ) -> torch.Tensor: if self.calculate_kv_scales and \ attn_metadata.enable_kv_scales_calculation: - self.calc_kv_scales(key, value) + self.calc_kv_scales(query, key, value) if self.use_output: output = torch.empty_like(query) hidden_size = query.size(-1) @@ -177,6 +180,13 @@ def forward( return torch.ops.vllm.unified_attention( query, key, value, self.layer_name) + def calc_kv_scales(self, query, key, value): + self._q_scale.copy_(torch.abs(query).max() / self.q_range) + self._k_scale.copy_(torch.abs(key).max() / self.k_range) + self._v_scale.copy_(torch.abs(value).max() / self.v_range) + # We only calculate the scales once + self.calculate_kv_scales = False + def calc_kv_scales(self, key, value): self._k_scale.copy_(torch.abs(key).max() / self.k_range) self._v_scale.copy_(torch.abs(value).max() / self.v_range) diff --git a/vllm/attention/ops/triton_flash_attention.py b/vllm/attention/ops/triton_flash_attention.py index ef04603f22b6e..dbb3faecaf949 100644 --- a/vllm/attention/ops/triton_flash_attention.py +++ b/vllm/attention/ops/triton_flash_attention.py @@ -24,6 +24,8 @@ import triton import triton.language as tl +from vllm.utils import is_navi + torch_dtype: tl.constexpr = torch.float16 @@ -104,6 +106,9 @@ def _attn_fwd_inner( ENABLE_DROPOUT: tl.constexpr, RETURN_ENCODED_SOFTMAX: tl.constexpr, PADDED_HEAD: tl.constexpr, + USE_FP8: tl.constexpr, + qk_scale, + p_descale, ): # loop over k, v, and update accumulator for start_n in range(block_min, block_max, BLOCK_N): @@ -145,6 +150,8 @@ def _attn_fwd_inner( qk = tl.where(causal_mask, qk, float("-inf")) # -- compute qk ---- qk += tl.dot(q, k) + if USE_FP8: + qk *= qk_scale if bias_ptr is not None: bias = load_fn(bias_ptr, False, MASK_STEPS and (n_extra_tokens != 0), "zero") @@ -196,7 +203,12 @@ def _attn_fwd_inner( l_i = l_i * alpha + l_ij # update m_i and l_i m_i = m_ij + + if USE_FP8: + p *= p_descale + acc += tl.dot(p.to(V_block_ptr.type.element_ty), v) + V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0)) K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N)) if bias_ptr is not None: @@ -207,103 +219,182 @@ def _attn_fwd_inner( return acc, l_i, m_i -@triton.autotune( - configs=[ +def get_cdna_autotune_configs(): + return [ triton.Config( { - "BLOCK_M": 256, - "BLOCK_N": 64, - "waves_per_eu": 2, - "PRE_LOAD_V": False, + 'BLOCK_M': 256, + 'BLOCK_N': 64, + 'waves_per_eu': 2, + 'PRE_LOAD_V': False }, num_stages=1, - num_warps=8, - ), + num_warps=8), triton.Config( { - "BLOCK_M": 128, - "BLOCK_N": 128, - "waves_per_eu": 2, - "PRE_LOAD_V": False, + 'BLOCK_M': 128, + 'BLOCK_N': 128, + 'waves_per_eu': 2, + 'PRE_LOAD_V': False }, num_stages=1, - num_warps=4, - ), + num_warps=4), triton.Config( { - "BLOCK_M": 256, - "BLOCK_N": 128, - "waves_per_eu": 2, - "PRE_LOAD_V": False, + 'BLOCK_M': 256, + 'BLOCK_N': 128, + 'waves_per_eu': 2, + 'PRE_LOAD_V': False }, num_stages=1, - num_warps=8, - ), + num_warps=8), triton.Config( { - "BLOCK_M": 128, - "BLOCK_N": 64, - "waves_per_eu": 1, - "PRE_LOAD_V": False, + 'BLOCK_M': 128, + 
'BLOCK_N': 64, + 'waves_per_eu': 1, + 'PRE_LOAD_V': False }, num_stages=1, - num_warps=4, - ), + num_warps=4), triton.Config( { - "BLOCK_M": 128, - "BLOCK_N": 64, - "waves_per_eu": 3, - "PRE_LOAD_V": True, + 'BLOCK_M': 128, + 'BLOCK_N': 64, + 'waves_per_eu': 3, + 'PRE_LOAD_V': True }, num_stages=1, - num_warps=4, - ), + num_warps=4), triton.Config( { - "BLOCK_M": 128, - "BLOCK_N": 64, - "waves_per_eu": 3, - "PRE_LOAD_V": False, + 'BLOCK_M': 128, + 'BLOCK_N': 64, + 'waves_per_eu': 3, + 'PRE_LOAD_V': False }, num_stages=1, - num_warps=4, - ), + num_warps=4), triton.Config( { - "BLOCK_M": 64, - "BLOCK_N": 64, - "waves_per_eu": 4, - "PRE_LOAD_V": False, + 'BLOCK_M': 64, + 'BLOCK_N': 64, + 'waves_per_eu': 4, + 'PRE_LOAD_V': False }, num_stages=1, - num_warps=8, - ), + num_warps=8), triton.Config( { - "BLOCK_M": 32, - "BLOCK_N": 32, - "waves_per_eu": 4, - "PRE_LOAD_V": False, + 'BLOCK_M': 32, + 'BLOCK_N': 32, + 'waves_per_eu': 4, + 'PRE_LOAD_V': False }, num_stages=1, - num_warps=8, - ), + num_warps=8), # TODO: This config fails with head_size not pow2 with data mismatches. # triton.Config({'BLOCK_M': 32, 'BLOCK_N': 16, 'waves_per_eu': 1, # 'PRE_LOAD_V': False}, num_stages=1, num_warps=4), + + # Fails in AccelerateAMDMatmul (Triton) assert when using FP8: + # triton.Config( + # { + # "BLOCK_M": 16, + # "BLOCK_N": 16, + # "waves_per_eu": 1, + # "PRE_LOAD_V": False, + # }, + # num_stages=1, + # num_warps=4, + # ), + ], ['IS_CAUSAL', 'dropout_p', 'BLOCK_DMODEL', 'USE_FP8'] + + +def get_rdna_autotune_configs(): + return [ + triton.Config( + { + 'BLOCK_M': 32, + 'BLOCK_N': 32, + 'waves_per_eu': 4, + 'PRE_LOAD_V': False + }, + num_stages=1, + num_warps=2), + triton.Config( + { + 'BLOCK_M': 32, + 'BLOCK_N': 32, + 'waves_per_eu': 2, + 'PRE_LOAD_V': False + }, + num_stages=1, + num_warps=2), + triton.Config( + { + 'BLOCK_M': 32, + 'BLOCK_N': 16, + 'waves_per_eu': 4, + 'PRE_LOAD_V': False + }, + num_stages=1, + num_warps=2), triton.Config( { - "BLOCK_M": 16, - "BLOCK_N": 16, - "waves_per_eu": 1, - "PRE_LOAD_V": False, + 'BLOCK_M': 32, + 'BLOCK_N': 16, + 'waves_per_eu': 2, + 'PRE_LOAD_V': False }, num_stages=1, - num_warps=4, - ), - ], - key=['IS_CAUSAL', 'dropout_p', 'BLOCK_DMODEL'], + num_warps=2), + # Fails in AccelerateAMDMatmul (Triton) assert when using FP8: + # triton.Config( + # { + # 'BLOCK_M': 16, + # 'BLOCK_N': 16, + # 'waves_per_eu': 4, + # 'PRE_LOAD_V': False + # }, + # num_stages=1, + # num_warps=2), + # triton.Config( + # { + # 'BLOCK_M': 16, + # 'BLOCK_N': 16, + # 'waves_per_eu': 2, + # 'PRE_LOAD_V': False + # }, + # num_stages=1, + # num_warps=2), + # # Fall-back config. 
+ # triton.Config( + # { + # 'BLOCK_M': 16, + # 'BLOCK_N': 16, + # 'waves_per_eu': 1, + # 'PRE_LOAD_V': False + # }, + # num_stages=1, + # num_warps=2), + ], ['IS_CAUSAL', 'dropout_p', 'BLOCK_DMODEL', 'USE_FP8'] + + +def get_autotune_configs(): + if is_navi(): + return get_rdna_autotune_configs() + else: + return get_cdna_autotune_configs() + + +autotune_configs, autotune_keys = get_autotune_configs() + +float8_info = torch.finfo(torch.float8_e4m3fnuz) + +@triton.autotune( + configs=autotune_configs, + key=autotune_keys, ) @triton.jit def attn_fwd( @@ -312,28 +403,34 @@ def attn_fwd( V, bias, sm_scale, + q_scale, + k_scale, + v_scale, + p_scale, + p_descale, + o_descale, L, Out, - stride_qz, - stride_qh, - stride_qm, - stride_qk, - stride_kz, - stride_kh, - stride_kn, - stride_kk, - stride_vz, - stride_vh, - stride_vk, - stride_vn, - stride_oz, - stride_oh, - stride_om, - stride_on, - stride_bz, - stride_bh, - stride_bm, - stride_bn, + stride_qz: tl.int64, + stride_qh: tl.int64, + stride_qm: tl.int64, + stride_qk: tl.int64, + stride_kz: tl.int64, + stride_kh: tl.int64, + stride_kn: tl.int64, + stride_kk: tl.int64, + stride_vz: tl.int64, + stride_vh: tl.int64, + stride_vk: tl.int64, + stride_vn: tl.int64, + stride_oz: tl.int64, + stride_oh: tl.int64, + stride_om: tl.int64, + stride_on: tl.int64, + stride_bz: tl.int64, + stride_bh: tl.int64, + stride_bm: tl.int64, + stride_bn: tl.int64, cu_seqlens_q, cu_seqlens_k, dropout_p, @@ -349,11 +446,14 @@ def attn_fwd( IS_CAUSAL: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, + USE_FP8: tl.constexpr, BLOCK_N: tl.constexpr, PRE_LOAD_V: tl.constexpr, BIAS_TYPE: tl.constexpr, ENABLE_DROPOUT: tl.constexpr, RETURN_ENCODED_SOFTMAX: tl.constexpr, + FP8_MIN: tl.constexpr = float8_info.min, + FP8_MAX: tl.constexpr = float8_info.max, ): start_m = tl.program_id(0) off_h_q = tl.program_id(1) @@ -507,7 +607,12 @@ def attn_fwd( qk_scale = sm_scale * 1.44269504089 # Q is loaded once at the beginning and shared by all N blocks. q = load_fn(Q_block_ptr, True, padded_head, "zero") - q = (q * qk_scale).to(Q_block_ptr.type.element_ty) + if not USE_FP8: + q = (q * qk_scale).to(Q_block_ptr.type.element_ty) + acc_scale = 1.0 + else: + qk_scale *= q_scale * k_scale + acc_scale = p_scale * v_scale # Here we compute how many full and masked blocks we have. 
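The scale folding in the kernel (`qk_scale *= q_scale * k_scale`, `acc_scale = p_scale * v_scale`) is ordinary per-tensor dequantization algebra: if `x ≈ x_q * x_scale`, then `Q·Kᵀ ≈ (Q_q·K_qᵀ)·q_scale·k_scale`, so the matmuls can stay in FP8 and the combined scale is applied once. A small numerical check of the identity (plain float tensors for portability; the 240 divisor roughly mirrors the e4m3fnuz range):

```python
import torch

torch.manual_seed(0)
q, k = torch.randn(4, 8), torch.randn(4, 8)

# Per-tensor scales sized so quantized values would fit the FP8 range.
q_scale = q.abs().max() / 240.0
k_scale = k.abs().max() / 240.0

# "Quantize" (divide only; the real path also clamps and casts to FP8).
q_q, k_q = q / q_scale, k / k_scale

exact = q @ k.T
folded = (q_q @ k_q.T) * (q_scale * k_scale)
assert torch.allclose(exact, folded, atol=1e-4)
```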
padded_block_k = n_extra_tokens != 0 @@ -562,6 +667,9 @@ def attn_fwd( ENABLE_DROPOUT, RETURN_ENCODED_SOFTMAX, padded_head, + USE_FP8, + qk_scale, + p_descale, ) block_min = block_max block_max = n_blocks * BLOCK_N @@ -608,8 +716,14 @@ def attn_fwd( ENABLE_DROPOUT, RETURN_ENCODED_SOFTMAX, padded_head, + USE_FP8, + qk_scale, + p_descale, ) # epilogue + + if USE_FP8: + acc *= acc_scale acc = acc / l_i[:, None] if ENABLE_DROPOUT: acc = acc / (1 - dropout_p) @@ -620,6 +734,9 @@ def attn_fwd( end_m_idx = (start_m + 1) * BLOCK_M start_m_idx = start_m * BLOCK_M causal_start_idx = seqlen_q - seqlen_k + if USE_FP8: + acc *= o_descale + acc = tl.clamp(acc, FP8_MIN, FP8_MAX) acc = acc.to(Out.type.element_ty) if IS_CAUSAL: # noqa: SIM102 if causal_start_idx > start_m_idx and causal_start_idx < end_m_idx: @@ -627,9 +744,9 @@ def attn_fwd( causal_start_idx, dtype=tl.int32) mask_m_offsets = start_m_idx + tl.arange(0, BLOCK_M) - out_ptrs_mask = (mask_m_offsets[:, None] - >= out_mask_boundary[None, :]) - z = 0.0 + out_ptrs_mask = (mask_m_offsets[:, None] >= + out_mask_boundary[None, :]) + z = tl.zeros((1, ), tl.float32) acc = tl.where(out_ptrs_mask, acc, z.to(acc.type.element_ty)) # write back LSE # l_ptrs = L + off_z * HQ * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q + offs_m @@ -710,7 +827,30 @@ def forward( causal=False, sm_scale=1.0, bias=None, + fp8_scales=None, ): + if fp8_scales is not None: + use_fp8 = True + (q_scale, k_scale, v_scale, p_scale, o_scale) = fp8_scales + float8 = torch.float8_e4m3fnuz + + def check_and_convert(t, scale): + if t.dtype != float8: + finfo = torch.finfo(float8) + descale = 1.0 / scale + ts = (t * descale).clamp(min=float8_info.min, + max=float8_info.max) + return ts.to(float8) + else: + return t + + q = check_and_convert(q, q_scale) + k = check_and_convert(k, k_scale) + v = check_and_convert(v, v_scale) + else: + use_fp8 = False + q_scale = k_scale = v_scale = p_scale = o_scale = 1.0 + if o is None: o = torch.empty_like(q, dtype=v.dtype) @@ -773,12 +913,24 @@ def forward( else: bias_strides = (0, 0, 0, 0) + p_descale = 1.0 / p_scale + o_descale = 1.0 / o_scale + + arg_max_seqlens_q = 0 if is_navi() else max_seqlens_q + arg_max_seqlens_k = 0 if is_navi() else max_seqlens_k + attn_fwd[grid]( q, k, v, bias, sm_scale, + q_scale, + k_scale, + v_scale, + p_scale, + p_descale, + o_descale, None, o, *q_strides, @@ -795,14 +947,15 @@ def forward( HQ=nheads_q, HK=nheads_k, ACTUAL_BLOCK_DMODEL=head_size, - MAX_SEQLENS_Q=max_seqlens_q, - MAX_SEQLENS_K=max_seqlens_k, + MAX_SEQLENS_Q=arg_max_seqlens_q, + MAX_SEQLENS_K=arg_max_seqlens_k, IS_CAUSAL=causal, VARLEN=True, BLOCK_DMODEL=padded_d_model, BIAS_TYPE=0 if bias is None else 1, ENABLE_DROPOUT=False, RETURN_ENCODED_SOFTMAX=False, + USE_FP8=use_fp8, ) ctx.grid = grid diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index ba96484e3fce9..ee861be4a1c32 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -193,6 +193,7 @@ class EngineArgs: worker_cls: str = "auto" kv_transfer_config: Optional[KVTransferConfig] = None + calculate_kv_scales: Optional[bool] = None generation_config: Optional[str] = None enable_sleep_mode: bool = False @@ -931,6 +932,15 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: default="auto", help='The worker class to use for distributed execution.') + parser.add_argument( + '--calculate-kv-scales', + action='store_true', + help='This enables dynamic calculation of ' + 'k_scale and v_scale when kv-cache-dtype is fp8. 
' + 'If calculate-kv-scales is false, the scales will ' + 'be loaded from the model checkpoint if available. ' + 'Otherwise, the scales will default to 1.0.') + parser.add_argument( "--generation-config", type=nullable_str, diff --git a/vllm/envs.py b/vllm/envs.py index 8627caec7790d..d944a47048ced 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -73,8 +73,14 @@ VLLM_ENABLE_V1_MULTIPROCESSING: bool = True VLLM_LOG_BATCHSIZE_INTERVAL: float = -1 VLLM_DISABLE_COMPILE_CACHE: bool = False - K_SCALE_CONSTANT: int = 200 - V_SCALE_CONSTANT: int = 100 + VLLM_USE_SDPA_ATTENTION: bool = False + VLLM_USE_ROCM_FP8_FLASH_ATTN: bool = False + VLLM_USE_ROCM_CUSTOM_PAGED_ATTN: bool = False + VLLM_USE_ROCM_CUSTOM_PAGED_ATTN_FP8_OUT: bool = False + Q_SCALE_CONSTANT: int = 20 + K_SCALE_CONSTANT: int = 20 + V_SCALE_CONSTANT: int = 10 + VLLM_FP8_PADDING: bool = True VLLM_SERVER_DEV_MODE: bool = False VLLM_V1_OUTPUT_PROC_CHUNK_SIZE: int = 128 @@ -478,11 +484,11 @@ def maybe_convert_int(value: Optional[str]) -> Optional[int]: # Divisor for dynamic key scale factor calculation for FP8 KV Cache "K_SCALE_CONSTANT": - lambda: int(os.getenv("K_SCALE_CONSTANT", "200")), + lambda: int(os.getenv("K_SCALE_CONSTANT", "20")), # Divisor for dynamic value scale factor calculation for FP8 KV Cache "V_SCALE_CONSTANT": - lambda: int(os.getenv("V_SCALE_CONSTANT", "100")), + lambda: int(os.getenv("V_SCALE_CONSTANT", "10")), # If set, enable multiprocessing in LLM for the V1 code path. "VLLM_ENABLE_V1_MULTIPROCESSING": lambda: bool(int(os.getenv("VLLM_ENABLE_V1_MULTIPROCESSING", "1"))), @@ -491,6 +497,30 @@ def maybe_convert_int(value: Optional[str]) -> Optional[int]: "VLLM_DISABLE_COMPILE_CACHE": lambda: bool(int(os.getenv("VLLM_DISABLE_COMPILE_CACHE", "0"))), + # flag to control if vllm should use naive scaled dot-product attention + "VLLM_USE_SDPA_ATTENTION": + lambda: (os.environ.get("VLLM_USE_SDPA_ATTENTION", "False").lower() in + ("true", "1")), + + # use quantized q,k,v,softmax(qk^T), attn output during prefill + "VLLM_USE_ROCM_FP8_FLASH_ATTN": + lambda: (os.getenv("VLLM_USE_ROCM_FP8_FLASH_ATTN", "False").lower() in + ("true", "1")), + + # have custom paged attention implemented for MI3* cards write out fp8 + "VLLM_USE_ROCM_CUSTOM_PAGED_ATTN_FP8_OUT": + lambda: + (os.getenv("VLLM_USE_ROCM_CUSTOM_PAGED_ATTN_FP8_OUT", "True").lower() in + ("true", "1")), + + # Divisor for dynamic query scale factor calculation for FP8 attention + "Q_SCALE_CONSTANT": + lambda: int(os.getenv("Q_SCALE_CONSTANT", "20")), + + # Pad the weight for moe kernel or not + "VLLM_FP8_PADDING": + lambda: bool(int(os.getenv("VLLM_FP8_PADDING", "1"))), + # If set, vllm will run in development mode, which will enable # some additional endpoints for developing and debugging, # e.g. 
`/reset_prefix_cache` diff --git a/vllm/model_executor/layers/activation.py b/vllm/model_executor/layers/activation.py index fb9684ac1c184..14b9d8cf53564 100644 --- a/vllm/model_executor/layers/activation.py +++ b/vllm/model_executor/layers/activation.py @@ -74,7 +74,10 @@ def forward_native(self, x: torch.Tensor) -> torch.Tensor: d = x.shape[-1] // 2 return F.silu(x[..., :d]) * x[..., d:] - def forward_cuda(self, x: torch.Tensor) -> torch.Tensor: + def forward_cuda(self, + x: torch.Tensor, + scale: Optional[torch.Tensor] = None) -> torch.Tensor: + d = x.shape[-1] // 2 output_shape = (x.shape[:-1] + (d, )) out = torch.empty(output_shape, dtype=x.dtype, device=x.device) diff --git a/vllm/model_executor/layers/layernorm.py b/vllm/model_executor/layers/layernorm.py index 43ea4eb5a4d1a..15f61466611ad 100644 --- a/vllm/model_executor/layers/layernorm.py +++ b/vllm/model_executor/layers/layernorm.py @@ -76,12 +76,24 @@ def forward_cuda( self, x: torch.Tensor, residual: Optional[torch.Tensor] = None, + scale: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: if self.variance_size_override is not None: return self.forward_native(x, residual) from vllm import _custom_ops as ops + if scale is not None: + out = torch.empty_like(x, dtype=torch.float8_e4m3fnuz) + if residual is not None: + ops.scaled_fused_add_rms_norm(out, x, residual, + self.weight.data, scale, + self.variance_epsilon) + return out, residual + ops.scaled_rms_norm(out, x, self.weight.data, scale, + self.variance_epsilon) + return out + if residual is not None: ops.fused_add_rms_norm( x, diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 21d4355b36ab0..5ba574d34010a 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -1,6 +1,7 @@ from typing import Any, Callable, Dict, List, Optional import torch +import torch.nn.functional as F from torch.nn import Module from torch.nn.parameter import Parameter @@ -28,6 +29,7 @@ PerTensorScaleParameter) from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform +from vllm.utils import is_navi ACTIVATION_SCHEMES = ["static", "dynamic"] @@ -111,6 +113,26 @@ def get_quant_method(self, layer: torch.nn.Module, return Fp8KVCacheMethod(self) return None + def get_cache_scale(self, name: str) -> Optional[str]: + """ + Check whether the param name matches the format for k/v cache scales + in compressed-tensors. If this is the case, return its equivalent + param name expected by vLLM + + :param name: param name + :return: matching param name for KV cache scale in vLLM + """ + if name.endswith(".output_scale") and ".k_proj" in name: + return name.replace(".k_proj.output_scale", ".attn.k_scale") + if name.endswith(".output_scale") and ".v_proj" in name: + return name.replace(".v_proj.output_scale", ".attn.v_scale") + if name.endswith(".output_scale") and ".q_proj" in name: + return name.replace(".q_proj.output_scale", ".attn.q_scale") + if name.endswith("self_attn.prob_output_scale"): + return name.replace(".prob_output_scale", ".attn.prob_scale") + # If no matches, return None + return None + class Fp8LinearMethod(LinearMethodBase): """Linear method for FP8. 
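`get_cache_scale` routes attention output scales stored under the projection modules of an FP8 checkpoint to the attention layer's own scale parameters. A standalone mirror of the mapping, applied to a hypothetical parameter name:

```python
def get_cache_scale(name):
    """Mirror of Fp8Config.get_cache_scale, for illustration only."""
    if name.endswith(".output_scale") and ".k_proj" in name:
        return name.replace(".k_proj.output_scale", ".attn.k_scale")
    if name.endswith(".output_scale") and ".v_proj" in name:
        return name.replace(".v_proj.output_scale", ".attn.v_scale")
    if name.endswith(".output_scale") and ".q_proj" in name:
        return name.replace(".q_proj.output_scale", ".attn.q_scale")
    if name.endswith("self_attn.prob_output_scale"):
        return name.replace(".prob_output_scale", ".attn.prob_scale")
    return None

print(get_cache_scale("model.layers.0.self_attn.k_proj.output_scale"))
# -> model.layers.0.self_attn.attn.k_scale
```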
@@ -136,6 +158,7 @@ def __init__(self, quant_config: Fp8Config): # For GPUs that lack FP8 hardware support, we can leverage the Marlin # kernel for fast weight-only FP8 quantization + self.out_dtype = torch.get_default_dtype() self.use_marlin = (not current_platform.has_device_capability(89) or envs.VLLM_TEST_FORCE_FP8_MARLIN) # Disable marlin for rocm @@ -161,6 +184,8 @@ def create_weights( weight_loader = extra_weight_attrs.get("weight_loader") if self.block_quant: + assert not envs.VLLM_FP8_PADDING, ( + "FP8 weight padding is not supported in block quantization.") tp_size = get_tensor_model_parallel_world_size() assert self.quant_config.weight_block_size is not None block_n, block_k = ( @@ -247,7 +272,7 @@ def create_weights( def process_weights_after_loading(self, layer: Module) -> None: # Block quant doesn't need to process weights after loading if self.block_quant: - if current_platform.is_rocm(): + if current_platform.is_rocm() and not is_navi(): weight, weight_scale, _ = \ normalize_e4m3fn_to_e4m3fnuz( weight=layer.weight, @@ -280,9 +305,13 @@ def process_weights_after_loading(self, layer: Module) -> None: # If checkpoint is fp8, handle that there are N scales for N # shards in a fused module else: + layer.weight_scale.data[layer.weight_scale.data == torch.finfo( + torch.float32).min] = 1 layer.weight_scale = torch.nn.Parameter(layer.weight_scale.data, requires_grad=False) if self.quant_config.activation_scheme == "static": + layer.input_scale.data[layer.input_scale.data == torch.finfo( + torch.float32).min] = 1 layer.input_scale = torch.nn.Parameter(layer.input_scale.data, requires_grad=False) # If using marlin (w8a16), kernel uses channelwise weights, @@ -299,8 +328,8 @@ def process_weights_after_loading(self, layer: Module) -> None: weight = layer.weight weight_scale = layer.weight_scale - # If rocm, use float8_e4m3fnuz. - if current_platform.is_rocm(): + # If rocm (except Navi4x), use float8_e4m3fnuz. + if current_platform.is_rocm() and not is_navi(): weight, weight_scale, input_scale = \ normalize_e4m3fn_to_e4m3fnuz( weight=weight, @@ -316,6 +345,14 @@ def process_weights_after_loading(self, layer: Module) -> None: logical_widths=layer.logical_widths, ) + # Pad the weight + if envs.VLLM_FP8_PADDING and weight.stride(-1) == 1 \ + and (weight.stride(-2) * weight.element_size()) % 512 == 0: + num_pad = 256 // weight.element_size() + weight = F.pad(weight, (0, num_pad), "constant", + 0)[..., :-num_pad] + torch.cuda.empty_cache() + # Update layer with new values. layer.weight = Parameter(weight.t(), requires_grad=False) layer.weight_scale = Parameter(weight_scale, requires_grad=False) @@ -361,6 +398,7 @@ def apply(self, input=x, weight=layer.weight, weight_scale=layer.weight_scale, + out_dtype=self.out_dtype, input_scale=layer.input_scale, bias=bias, cutlass_fp8_supported=self.cutlass_fp8_supported, @@ -509,7 +547,7 @@ def create_weights(self, layer: Module, num_experts: int, hidden_size: int, def process_weights_after_loading(self, layer: Module) -> None: # Block quant doesn't need to process weights after loading if self.block_quant: - if current_platform.is_rocm(): + if current_platform.is_rocm() and not is_navi(): w13_weight, w13_weight_scale_inv, w13_input_scale = \ normalize_e4m3fn_to_e4m3fnuz( layer.w13_weight, layer.w13_weight_scale_inv, @@ -536,9 +574,9 @@ def process_weights_after_loading(self, layer: Module) -> None: return # If checkpoint is fp16, quantize in place. 
if not self.quant_config.is_checkpoint_fp8_serialized: - # If rocm, use float8_e4m3fnuz as dtype - fp8_dtype = torch.float8_e4m3fnuz \ - if current_platform.is_rocm() else torch.float8_e4m3fn + # If rocm (except Navi4x), use float8_e4m3fnuz as dtype + fp8_dtype = (torch.float8_e4m3fnuz if current_platform.is_rocm() + and not is_navi() else torch.float8_e4m3fn) w13_weight = torch.empty_like(layer.w13_weight.data, dtype=fp8_dtype) w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype) @@ -585,8 +623,9 @@ def process_weights_after_loading(self, layer: Module) -> None: layer.w13_input_scale.max(), requires_grad=False) layer.w2_input_scale = torch.nn.Parameter( layer.w2_input_scale.max(), requires_grad=False) - # If rocm, normalize the weights and scales to e4m3fnuz - if current_platform.is_rocm(): + # If rocm (except Navi4x, which uses e4m3fn), + # normalize the weights and scales to e4m3fnuz + if current_platform.is_rocm() and not is_navi(): # Normalize the weights and scales w13_weight, w13_weight_scale, w13_input_scale = \ normalize_e4m3fn_to_e4m3fnuz( diff --git a/vllm/model_executor/layers/quantization/kv_cache.py b/vllm/model_executor/layers/quantization/kv_cache.py index e1870c73cc932..ef54c25a05573 100644 --- a/vllm/model_executor/layers/quantization/kv_cache.py +++ b/vllm/model_executor/layers/quantization/kv_cache.py @@ -1,9 +1,11 @@ import torch +import vllm.envs as envs from vllm.logger import init_logger from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase) from vllm.platforms import current_platform +from vllm.utils import is_navi logger = init_logger(__name__) @@ -33,57 +35,91 @@ def create_weights(self, layer: torch.nn.Module): requires_grad=False) layer.v_scale = torch.nn.Parameter(torch.tensor(-1.0), requires_grad=False) + # Initialize Q and P = softmax(QK^T) scales + layer.q_scale = torch.nn.Parameter(torch.tensor(-1.0), + requires_grad=False) + layer.prob_scale = torch.nn.Parameter(torch.tensor(-1.0), + requires_grad=False) def apply(self, layer: torch.nn.Module) -> torch.Tensor: raise RuntimeError( f"{self.__class__.__name__}.apply should not be called.") def process_weights_after_loading(self, layer: torch.nn.Module) -> None: - # If the kv-cache dtype is auto, we enforce the k/v_scale to be 1.0 - # regardless whether the kv-scale is available in the checkpoint. - # No need to process kv scales after loading if we are going to - # calculate them on the fly. 
- if layer.kv_cache_dtype != "auto" and not layer.calculate_kv_scales: - if layer.k_scale > 0.0 and layer.v_scale > 0.0: - # We prefer to use separate k_scale and v_scale if present - k_scale = layer.k_scale.to("cpu").tolist() - v_scale = layer.v_scale.to("cpu").tolist() - if current_platform.is_rocm(): - k_scale *= 2 - v_scale *= 2 - elif layer.k_scale < 0.0 and layer.v_scale < 0.0: - # If no scales were loaded (both scales are invalid negative - # values), use the default value of 1.0 - k_scale = 1.0 - v_scale = 1.0 - else: - # If we find a single kv_scale in the checkpoint, we remap - # kv_scale to k_scale during weight loading, and duplicate - # k_scale to v_scale here - assert layer.k_scale > 0.0 - scale_to_duplicate = max(layer.k_scale, layer.v_scale) - k_scale = scale_to_duplicate.to("cpu").tolist() - v_scale = scale_to_duplicate.to("cpu").tolist() - if current_platform.is_rocm(): - k_scale *= 2 - v_scale *= 2 + if layer.k_scale > 0.0 and layer.v_scale > 0.0: + # We prefer to use separate k_scale and v_scale if present + k_scale = layer.k_scale.to("cpu").tolist() + v_scale = layer.v_scale.to("cpu").tolist() + if current_platform.is_rocm() and not is_navi(): + k_scale *= 2 + v_scale *= 2 + layer.calculate_kv_scales = False + elif layer.k_scale < 0.0 and layer.v_scale < 0.0: + # If no scales were loaded (both scales are invalid negative + # values), use the default value of 1.0 + k_scale = 1.0 + v_scale = 1.0 + else: + # If we find a single kv_scale in the checkpoint, we remap + # kv_scale to k_scale during weight loading, and duplicate + # k_scale to v_scale here + assert layer.k_scale > 0.0 + scale_to_duplicate = max(layer.k_scale, layer.v_scale) + k_scale = scale_to_duplicate.to("cpu").tolist() + v_scale = scale_to_duplicate.to("cpu").tolist() + if current_platform.is_rocm() and not is_navi(): + k_scale *= 2 + v_scale *= 2 + layer.calculate_kv_scales = False + + if not isinstance(k_scale, float) or not isinstance(v_scale, float): + raise ValueError("Only support per-tensor scaling factor " + "for fp8 KV cache") + + # These are used in the final Attention.forward() + layer._k_scale.copy_(k_scale) + layer._v_scale.copy_(v_scale) + layer._k_scale_float = k_scale + layer._v_scale_float = v_scale + if (k_scale == 1.0 and v_scale == 1.0 + and (layer.kv_cache_dtype != "auto" + or envs.VLLM_USE_ROCM_FP8_FLASH_ATTN) + and "e5m2" not in layer.kv_cache_dtype): + logger.warning_once( + "Using KV cache scaling factor 1.0 for fp8_e4m3. This " + "may cause accuracy issues. 
Please make sure k/v_scale " + "scaling factors are available in the fp8 checkpoint.") + + if layer.q_scale > 0.0: + q_scale = layer.q_scale.to("cpu").tolist() + if current_platform.is_rocm() and not is_navi(): + q_scale *= 2 + layer.calculate_kv_scales = False + else: + q_scale = 1.0 + if layer.prob_scale > 0.0: + prob_scale = layer.prob_scale.to("cpu").tolist() + if current_platform.is_rocm() and not is_navi(): + prob_scale *= 2 + else: + prob_scale = 1.0 - if not isinstance(k_scale, float) or not isinstance( - v_scale, float): - raise ValueError("Only support per-tensor scaling factor " - "for fp8 KV cache") + if not isinstance(q_scale, float) or not isinstance(prob_scale, float): + raise ValueError("Only support per-tensor scaling factor " + "for fp8-quantized Q/prob") - # These are used in the final Attention.forward() - layer._k_scale.copy_(k_scale) - layer._v_scale.copy_(v_scale) - layer._k_scale_float = k_scale - layer._v_scale_float = v_scale - if (k_scale == 1.0 and v_scale == 1.0 - and "e5m2" not in layer.kv_cache_dtype): - logger.warning_once( - "Using KV cache scaling factor 1.0 for fp8_e4m3. This " - "may cause accuracy issues. Please make sure k/v_scale " - "scaling factors are available in the fp8 checkpoint.") + # These are used in the final Attention.forward() + layer._q_scale.copy_(q_scale) + layer._prob_scale.copy_(prob_scale) + if (q_scale == 1.0 + or prob_scale == 1.0) and envs.VLLM_USE_ROCM_FP8_FLASH_ATTN: + logger.warning_once( + f"Using Q scale {q_scale} and prob scale {prob_scale} " + "with fp8 attention. This may cause accuracy issues. " + "Please make sure Q/prob scaling factors are " + "available in the fp8 checkpoint.") del layer.k_scale del layer.v_scale + del layer.q_scale + del layer.prob_scale diff --git a/vllm/model_executor/layers/quantization/quark/quark.py b/vllm/model_executor/layers/quantization/quark/quark.py index fc214255eca71..144036814fafa 100644 --- a/vllm/model_executor/layers/quantization/quark/quark.py +++ b/vllm/model_executor/layers/quantization/quark/quark.py @@ -1,5 +1,4 @@ import fnmatch -import re from typing import Any, Dict, List, Optional, cast import torch @@ -122,6 +121,12 @@ def from_config(cls, config: Dict[str, Any]) -> "QuarkConfig": for q_config in q_configs: q_config["output_tensors"] = None + # If the q_proj output is also quantized, drop that configuration + # so the q/k/v projections are handled consistently. + q_proj_q_config = cast(Dict[str, Any], + layer_quant_config.get("*q_proj")) + q_proj_q_config["output_tensors"] = None + return cls(quant_config=config, kv_cache_group=kv_cache_group, kv_cache_config=kv_cache_config, @@ -148,6 +153,19 @@ def _check_scheme_supported(self, else: return False + def is_fp8_w8a8(self) -> bool: + # Returns True if all quantized layers in the model are fp8 w8a8 + global_quant_config = cast( + Dict[str, Any], self.quant_config.get("global_quant_config")) + layer_quant_configs = cast(Dict[str, Any], + self.quant_config.get("layer_quant_config")) + for config in (global_quant_config, *layer_quant_configs.values()): + weight_config = cast(Dict[str, Any], config.get("weight")) + input_config = cast(Dict[str, Any], config.get("input_tensors")) + if not self._is_fp8_w8a8(weight_config, input_config): + return False + return True + def _is_fp8_w8a8(self, weight_quant: Optional[Dict[str, Any]], input_quant: Optional[Dict[str, Any]]) -> bool: # Confirm weights and input quantized. 
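For reference, the checkpoint-scale resolution performed by Fp8KVCacheMethod.process_weights_after_loading above can be summarized by the small sketch below; resolve_kv_scales is an illustrative helper rather than vLLM API, and the doubling on ROCm (non-Navi) is presumably there because the float8_e4m3fnuz dtype used for the cache on those GPUs covers roughly half the numeric range of the float8_e4m3fn format the checkpoint scales were computed for:

    def resolve_kv_scales(k_scale: float, v_scale: float,
                          double_for_fnuz: bool) -> tuple[float, float]:
        if k_scale > 0.0 and v_scale > 0.0:
            # Separate k and v scales were found in the checkpoint.
            k, v = k_scale, v_scale
        elif k_scale < 0.0 and v_scale < 0.0:
            # No scales in the checkpoint: fall back to 1.0 (no doubling).
            return 1.0, 1.0
        else:
            # A single kv_scale was remapped to k_scale; duplicate it.
            k = v = max(k_scale, v_scale)
        if double_for_fnuz:
            # ROCm (non-Navi) stores the cache as float8_e4m3fnuz.
            k, v = k * 2, v * 2
        return k, v

For example, resolve_kv_scales(-1.0, -1.0, True) yields (1.0, 1.0), while resolve_kv_scales(0.02, -1.0, True) yields (0.04, 0.04); the q_scale and prob_scale handling added above follows the same pattern, defaulting to 1.0 when no scale is present.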
@@ -286,25 +304,14 @@ def get_cache_scale(self, name: str) -> Optional[str]: :param name: param name :return: matching param name for KV cache scale in vLLM """ - if self.kv_cache_group is None or len(self.kv_cache_group) == 0: - return None - - kv_proj_names = [ - re.split(r"[*.]", kv_cache)[-1] for kv_cache in self.kv_cache_group - ] - if name.endswith(".output_scale"): - if len(kv_proj_names) == 1 and kv_proj_names[0] in name: - kv_output_scale_name = "." + kv_proj_names[0] + ".output_scale" - return name.replace(kv_output_scale_name, ".attn.k_scale") - - elif len(kv_proj_names) == 2: - for kv_proj_name in kv_proj_names: - if kv_proj_name in name and kv_proj_name == "k_proj": - return name.replace(".k_proj.output_scale", - ".attn.k_scale") - elif kv_proj_name in name and kv_proj_name == "v_proj": - return name.replace(".v_proj.output_scale", - ".attn.v_scale") + if name.endswith(".output_scale") and ".k_proj" in name: + return name.replace(".k_proj.output_scale", ".attn.k_scale") + if name.endswith(".output_scale") and ".v_proj" in name: + return name.replace(".v_proj.output_scale", ".attn.v_scale") + if name.endswith(".output_scale") and ".q_proj" in name: + return name.replace(".q_proj.output_scale", ".attn.q_scale") + if name.endswith("self_attn.prob_output_scale"): + return name.replace(".prob_output_scale", ".attn.prob_scale") # If no matches, return None return None diff --git a/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py b/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py index 206931ea2ffc0..447911a648639 100644 --- a/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py +++ b/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py @@ -21,6 +21,7 @@ def __init__(self, qscheme: str, is_static_input_scheme: Optional[bool]): self.qscheme = qscheme self.is_static_input_scheme = is_static_input_scheme self.cutlass_fp8_supported = cutlass_fp8_supported() + self.out_dtype = torch.get_default_dtype() @classmethod def get_min_capability(cls) -> int: @@ -134,6 +135,7 @@ def apply_weights(self, input=x, weight=layer.weight, weight_scale=layer.weight_scale, + out_dtype=self.out_dtype, input_scale=layer.input_scale, bias=bias, cutlass_fp8_supported=self.cutlass_fp8_supported, diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py index 43b1997019107..e96f78c19e439 100644 --- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py @@ -6,6 +6,7 @@ import triton.language as tl from vllm.platforms import current_platform +from vllm.utils import is_navi def apply_w8a8_block_fp8_linear( @@ -41,8 +42,8 @@ def input_to_float8( """This function quantizes input values to float8 values " "with tensor-wise quantization.""" if dtype is None: - dtype = (torch.float8_e4m3fnuz - if current_platform.is_rocm() else torch.float8_e4m3fn) + dtype = (torch.float8_e4m3fnuz if current_platform.is_rocm() + and not is_navi() else torch.float8_e4m3fn) finfo = torch.finfo(dtype) min_val, max_val = x.aminmax() amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12) @@ -147,8 +148,8 @@ def per_token_group_quant_fp8( scaling factor for quantization. 
""" if dtype is None: - dtype = (torch.float8_e4m3fnuz - if current_platform.is_rocm() else torch.float8_e4m3fn) + dtype = (torch.float8_e4m3fnuz if current_platform.is_rocm() + and not is_navi() else torch.float8_e4m3fn) assert (x.shape[-1] % group_size == 0), ( f"the last dimension of `x` {x.shape[-1]} must be divisible " f"by `group_size` {group_size}") diff --git a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py index 9977804188a50..c93a3951731e8 100644 --- a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py @@ -94,6 +94,7 @@ def apply_fp8_linear( input: torch.Tensor, weight: torch.Tensor, weight_scale: torch.Tensor, + out_dtype: Optional[torch.dtype] = None, input_scale: Optional[torch.Tensor] = None, input_scale_ub: Optional[torch.Tensor] = None, bias: Optional[torch.Tensor] = None, @@ -108,6 +109,9 @@ def apply_fp8_linear( input_2d = input.view(-1, input.shape[-1]) output_shape = [*input.shape[:-1], weight.shape[1]] + if out_dtype is None: + out_dtype = input.dtype + # cutlass_scaled_mm supports per tensor/channel W and per tensor/token A if cutlass_fp8_supported: qinput, x_scale = ops.scaled_fp8_quant( @@ -131,11 +135,14 @@ def apply_fp8_linear( # Note: we pad the input because torch._scaled_mm is more performant # for matrices with batch dimension > 16. # This could change in the future. - qinput, x_scale = ops.scaled_fp8_quant( - input_2d, - input_scale, - num_token_padding=17, - use_per_token_if_dynamic=use_per_token_if_dynamic) + if input.dtype != torch.float8_e4m3fnuz: + qinput, x_scale = ops.scaled_fp8_quant( + input_2d, + input_scale, + num_token_padding=17, + use_per_token_if_dynamic=use_per_token_if_dynamic) + else: + qinput, x_scale = input_2d, input_scale per_tensor_weights = (weight_scale.numel() == 1) per_tensor_activations = (x_scale.numel() == 1) @@ -144,7 +151,7 @@ def apply_fp8_linear( # Fused GEMM_DQ output = torch._scaled_mm(qinput, weight, - out_dtype=input.dtype, + out_dtype=out_dtype, scale_a=x_scale, scale_b=weight_scale, bias=bias) diff --git a/vllm/model_executor/models/exaone.py b/vllm/model_executor/models/exaone.py index bc3295da7b60a..a76e9c5d70f6f 100644 --- a/vllm/model_executor/models/exaone.py +++ b/vllm/model_executor/models/exaone.py @@ -535,8 +535,9 @@ def load_weights(self, weights: Iterable[Tuple[str, param = params_dict[scale_name] weight_loader = getattr(param, "weight_loader", default_weight_loader) - loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else - loaded_weight[0]) + if loaded_weight.shape: + # scalar shape is torch.Size([1]), not torch.Size([]) + loaded_weight = loaded_weight[0] weight_loader(param, loaded_weight) loaded_params.add(scale_name) continue diff --git a/vllm/model_executor/models/granite.py b/vllm/model_executor/models/granite.py index 543b4e2f5e286..b4ebe493cb55d 100644 --- a/vllm/model_executor/models/granite.py +++ b/vllm/model_executor/models/granite.py @@ -477,8 +477,9 @@ def load_weights(self, weights: Iterable[Tuple[str, param = params_dict[scale_name] weight_loader = getattr(param, "weight_loader", default_weight_loader) - loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else - loaded_weight[0]) + if loaded_weight.shape: + # scalar shape is torch.Size([1]), not torch.Size([]) + loaded_weight = loaded_weight[0] weight_loader(param, loaded_weight) loaded_params.add(scale_name) continue diff --git a/vllm/model_executor/models/llama.py 
b/vllm/model_executor/models/llama.py index e214c30f5d60b..4a3f81c9f6770 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -26,6 +26,7 @@ from torch import nn from transformers import LlamaConfig +from vllm import _custom_ops as ops from vllm.attention import Attention, AttentionMetadata from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig @@ -37,6 +38,7 @@ RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.quantization.fp8 import Fp8Config from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -44,7 +46,9 @@ from vllm.model_executor.model_loader.weight_utils import ( default_weight_loader, maybe_remap_kv_scale_name) from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors +from vllm.utils import is_navi from .interfaces import SupportsLoRA, SupportsPP from .utils import (AutoWeightsLoader, PPMissingLayer, extract_layer_index, @@ -79,14 +83,27 @@ def __init__( quant_config=quant_config, prefix=f"{prefix}.down_proj", ) + self.use_fp8 = (isinstance(quant_config, Fp8Config) + if current_platform.is_rocm() and not is_navi() else + False) if hidden_act != "silu": raise ValueError(f"Unsupported activation: {hidden_act}. " "Only silu is supported for now.") self.act_fn = SiluAndMul() def forward(self, x): - x, _ = self.gate_up_proj(x) - x = self.act_fn(x) + if current_platform.is_rocm() and x.shape[0] == 1 and x.shape[1] == 1: + out = torch.empty(x.shape[0], + self.gate_up_proj.weight.shape[0] // 2, + dtype=x.dtype, + device=x.device) + ops.LLMM_Silu(self.gate_up_proj.weight, x.view(-1, x.size(-1)), + out, 8) + x = out.view(x.shape[0], x.shape[1], out.shape[1]) + else: + x, _ = self.gate_up_proj(x) + x = self.act_fn( + x, self.down_proj.input_scale if self.use_fp8 else None) x, _ = self.down_proj(x) return x @@ -214,6 +231,9 @@ def __init__( ) -> None: super().__init__() self.hidden_size = config.hidden_size + self.use_fp8 = (isinstance(quant_config, Fp8Config) + if current_platform.is_rocm() and not is_navi() else + False) rope_theta = getattr(config, "rope_theta", 10000) rope_scaling = getattr(config, "rope_scaling", None) if rope_scaling is not None and getattr( @@ -268,20 +288,23 @@ def forward( residual: Optional[torch.Tensor], ) -> Tuple[torch.Tensor, torch.Tensor]: # Self Attention + scale = None if not self.use_fp8 else \ + self.self_attn.qkv_proj.input_scale if residual is None: residual = hidden_states - hidden_states = self.input_layernorm(hidden_states) + hidden_states = self.input_layernorm(hidden_states, None, scale) else: hidden_states, residual = self.input_layernorm( - hidden_states, residual) + hidden_states, residual, scale) hidden_states = self.self_attn(positions=positions, hidden_states=hidden_states, kv_cache=kv_cache, attn_metadata=attn_metadata) # Fully Connected + scale = None if not self.use_fp8 else self.mlp.gate_up_proj.input_scale hidden_states, residual = self.post_attention_layernorm( - hidden_states, residual) + hidden_states, residual, scale) hidden_states = self.mlp(hidden_states) return hidden_states, residual diff --git a/vllm/model_executor/models/solar.py 
b/vllm/model_executor/models/solar.py index e6d919f23c85d..de1b631d2b165 100644 --- a/vllm/model_executor/models/solar.py +++ b/vllm/model_executor/models/solar.py @@ -494,8 +494,9 @@ def load_weights(self, weights: Iterable[Tuple[str, param = params_dict[scale_name] weight_loader = getattr(param, "weight_loader", default_weight_loader) - loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else - loaded_weight[0]) + if loaded_weight.shape: + # scalar shape is torch.Size([1]), not torch.Size([]) + loaded_weight = loaded_weight[0] weight_loader(param, loaded_weight) loaded_params.add(scale_name) continue diff --git a/vllm/utils.py b/vllm/utils.py index 15481fb06e08e..21504060f090e 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -1555,6 +1555,27 @@ def contains(self, key: object, *, strict: bool = False) -> bool: return any(cls in self.data for cls in key.mro()) +@lru_cache(maxsize=None) +def is_navi() -> bool: + from vllm.platforms import current_platform + if not current_platform.is_rocm() or not torch.cuda.is_available(): + return False + # All (visible) GPUs must be of the same type, + # otherwise FP8 results can't be guaranteed. + archName = torch.cuda.get_device_properties('cuda').gcnArchName + return archName is not None and "gfx1" in archName + + +@lru_cache(maxsize=None) +def is_navi3() -> bool: + from vllm.platforms import current_platform + if not current_platform.is_rocm() or not torch.cuda.is_available(): + return False + # All (visible) GPUs must be of the same type, + # otherwise FP8 results can't be guaranteed. + archName = torch.cuda.get_device_properties('cuda').gcnArchName + return archName is not None and "gfx11" in archName + def weak_ref_tensor(tensor: torch.Tensor) -> torch.Tensor: """ Create a weak reference to a tensor. diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index bf1a40d48a789..b38d7ca7e3d33 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -124,6 +124,8 @@ def from_broadcasted_tensor_dict( if attn_backend is not None: tensor_dict = _init_attn_metadata_from_tensor_dict( attn_backend, tensor_dict) + if "enable_kv_scales_calculation" in tensor_dict: + tensor_dict.pop("enable_kv_scales_calculation") return cls(**tensor_dict) # Exclude `async_callback` to be able to pickle this object diff --git a/vllm/worker/model_runner_base.py b/vllm/worker/model_runner_base.py index aef4bdcdd4bf9..8907cea078cca 100644 --- a/vllm/worker/model_runner_base.py +++ b/vllm/worker/model_runner_base.py @@ -47,7 +47,8 @@ def _init_attn_metadata_from_tensor_dict( # Extract the fields used to create AttentionMetadata. valid_attn_kwargs = {} for field in dataclasses.fields(attn_backend.get_metadata_cls()): - if field.name in tensor_dict: + if field.name in tensor_dict and field.name != \ + 'enable_kv_scales_calculation': valid_attn_kwargs[field.name] = tensor_dict.pop(field.name) attn_metadata = attn_backend.make_metadata(**valid_attn_kwargs)
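The loader changes in exaone.py, granite.py and solar.py all address the same detail: a checkpoint scale may arrive either as a true 0-d scalar or as a 1-element tensor, and only the latter needs to be indexed before being handed to the weight loader. A quick standalone illustration (plain PyTorch, not vLLM code):

    import torch

    for loaded_weight in (torch.tensor(0.5), torch.tensor([0.5])):
        # torch.Size([]) is falsy and torch.Size([1]) is truthy, so the check
        # `if loaded_weight.shape:` picks out the 1-element case.
        value = loaded_weight[0] if loaded_weight.shape else loaded_weight
        print(tuple(loaded_weight.shape), float(value))
    # prints: () 0.5
    #         (1,) 0.5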