From 2f5e5eae9a408c62bc07b3f9b15cd5cc30620087 Mon Sep 17 00:00:00 2001
From: zhink <771809832@qq.com>
Date: Thu, 14 Nov 2024 14:25:16 +0800
Subject: [PATCH] use fp8 cuda core gemm kernel when M<=4

---
 .../fp8_fp8_half_cuda_core_gemm.cu | 191 ++++++++++++++++++
 .../fp8_fp8_half_cuda_core_gemm.h  |  37 ++++
 .../fp8_fp8_half_gemm.cu           |  70 +++++--
 csrc/setup_cuda.py                 |   1 +
 llm/docs/predict/best_practices.md |   4 +
 5 files changed, 283 insertions(+), 20 deletions(-)
 create mode 100644 csrc/gpu/fp8_gemm_with_cutlass/fp8_fp8_half_cuda_core_gemm.cu
 create mode 100644 csrc/gpu/fp8_gemm_with_cutlass/fp8_fp8_half_cuda_core_gemm.h

diff --git a/csrc/gpu/fp8_gemm_with_cutlass/fp8_fp8_half_cuda_core_gemm.cu b/csrc/gpu/fp8_gemm_with_cutlass/fp8_fp8_half_cuda_core_gemm.cu
new file mode 100644
index 000000000000..06295cd62207
--- /dev/null
+++ b/csrc/gpu/fp8_gemm_with_cutlass/fp8_fp8_half_cuda_core_gemm.cu
@@ -0,0 +1,191 @@
+// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "fp8_fp8_half_cuda_core_gemm.h"
+#include "cutlass/numeric_conversion.h"
+
+// One thread block computes a TILE_M x TILE_N output tile. Each thread
+// accumulates partial dot products over a strided slice of K using plain
+// CUDA-core FMAs, which outperforms tensor cores for very small M.
+template <typename InputType,
+          typename OutputType,
+          int32_t TILE_M,
+          int32_t TILE_N,
+          int32_t BLOCK_SIZE,
+          bool UseBias>
+__global__ void cudaCoreGemm(InputType const* __restrict__ act,
+                             InputType const* __restrict__ weight,
+                             OutputType const* __restrict__ bias,
+                             OutputType* __restrict__ output,
+                             int32_t m,
+                             int32_t n,
+                             int32_t k,
+                             float alpha) {
+  // Each load step moves 128 bits (one int4) of quantized inputs.
+  using VecType = int4;
+  static constexpr int32_t kStepK =
+      static_cast<int32_t>(128 / (8 * sizeof(InputType)));
+  static constexpr int32_t kTileK = kStepK * BLOCK_SIZE;
+  auto tileIdM = static_cast<int32_t>(blockIdx.x * TILE_M);
+  auto tileIdN = static_cast<int32_t>(blockIdx.y * TILE_N);
+  auto tid = static_cast<int32_t>(threadIdx.x);
+  float tile_a[kStepK], tile_w[TILE_N * kStepK];
+  float acc[TILE_M * TILE_N];
+
+  static_assert(kStepK % 4 == 0);
+  using Converter = cutlass::NumericArrayConverter<float, InputType, 4>;
+  using CvtSrcType = typename Converter::source_type;
+  using CvtResType = typename Converter::result_type;
+
+  static constexpr int32_t kCvtCount =
+      static_cast<int32_t>(sizeof(VecType) / sizeof(CvtSrcType));
+
+#pragma unroll
+  for (int32_t i = 0; i < TILE_M * TILE_N; ++i) {
+    acc[i] = 0;
+  }
+  act += tileIdM * k;
+  weight += tileIdN * k;
+  output += tileIdM * n + tileIdN;
+  if constexpr (UseBias) {
+    bias += tileIdN;
+  }
+  for (int32_t idxK = tid * kStepK; idxK < k; idxK += kTileK) {
+    for (int32_t i = 0; i < TILE_N; ++i) {
+      auto tile_w_quantized =
+          reinterpret_cast<VecType const*>(weight + i * k + idxK)[0];
+#pragma unroll
+      for (int32_t cvtIdx = 0; cvtIdx < kCvtCount; ++cvtIdx) {
+        reinterpret_cast<CvtResType*>(tile_w)[i * kCvtCount + cvtIdx] =
+            Converter::convert(
+                reinterpret_cast<CvtSrcType*>(&tile_w_quantized)[cvtIdx]);
+      }
+    }
+#pragma unroll
+    for (int32_t i = 0; i < TILE_M; ++i) {
+      auto tile_a_quantized =
+          reinterpret_cast<VecType const*>(act + i * k + idxK)[0];
+#pragma unroll
+      for (int32_t cvtIdx = 0; cvtIdx < kCvtCount; ++cvtIdx) {
+        reinterpret_cast<CvtResType*>(tile_a)[cvtIdx] = Converter::convert(
+            reinterpret_cast<CvtSrcType*>(&tile_a_quantized)[cvtIdx]);
+      }
+#pragma unroll
+      for (int32_t j = 0; j < TILE_N; ++j) {
+#pragma unroll
+        for (int32_t l = 0; l < kStepK; ++l) {
+          acc[i * TILE_N + j] =
+              fma(tile_a[l], tile_w[j * kStepK + l], acc[i * TILE_N + j]);
+        }
+      }
+    }
+  }
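+
+  // Each thread now holds TILE_M * TILE_N partial sums covering its strided
+  // slice of K. Reduce them block-wide: first within each warp via
+  // cub::WarpReduce, then across warps through shared memory.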
+  typedef cub::WarpReduce<float> WarpReduce;
+
+  static constexpr int32_t kWarpSize = 32;
+  static constexpr int32_t kWarpNum = BLOCK_SIZE / kWarpSize;
+  int32_t warpId = tid / kWarpSize, laneId = tid % kWarpSize;
+  __shared__ float shmem[TILE_M * TILE_N * kWarpNum];
+  __shared__ typename WarpReduce::TempStorage tempStorage[kWarpNum];
+#pragma unroll
+  for (int32_t mi = 0; mi < TILE_M; ++mi) {
+#pragma unroll
+    for (int32_t ni = 0; ni < TILE_N; ++ni) {
+      float val = WarpReduce(tempStorage[warpId]).Sum(acc[mi * TILE_N + ni]);
+      if (laneId == 0) {
+        shmem[mi * TILE_N + ni + warpId * TILE_M * TILE_N] = val;
+      }
+    }
+  }
+
+  __syncthreads();
+  for (int32_t ii = tid; ii < TILE_M * TILE_N; ii += BLOCK_SIZE) {
+    int32_t mid = ii / TILE_N, nid = ii % TILE_N;
+    float val = 0;
+#pragma unroll
+    for (int32_t jj = 0; jj < kWarpNum; ++jj) {
+      val += shmem[jj * TILE_M * TILE_N + ii];
+    }
+
+    // Apply the dequantization scale alpha, plus bias when present.
+    if constexpr (UseBias) {
+      output[mid * n + nid] =
+          static_cast<OutputType>(val * alpha + static_cast<float>(bias[nid]));
+    } else {
+      output[mid * n + nid] = static_cast<OutputType>(val * alpha);
+    }
+  }
+}
+
+template <typename InputType,
+          typename OutputType,
+          int32_t TILE_M,
+          int32_t TILE_N,
+          int32_t BLOCK_SIZE>
+void cudaCoreGemmKernel(GemmParams const& params) {
+  dim3 block(BLOCK_SIZE);
+  dim3 grid(params.m / TILE_M, params.n / TILE_N);
+
+  if (params.bias != nullptr) {
+    cudaCoreGemm<InputType, OutputType, TILE_M, TILE_N, BLOCK_SIZE, true>
+        <<<grid, block, 0, params.stream>>>(
+            reinterpret_cast<InputType const*>(params.act),
+            reinterpret_cast<InputType const*>(params.weight),
+            reinterpret_cast<OutputType const*>(params.bias),
+            reinterpret_cast<OutputType*>(params.output),
+            params.m,
+            params.n,
+            params.k,
+            params.alpha);
+  } else {
+    cudaCoreGemm<InputType, OutputType, TILE_M, TILE_N, BLOCK_SIZE, false>
+        <<<grid, block, 0, params.stream>>>(
+            reinterpret_cast<InputType const*>(params.act),
+            reinterpret_cast<InputType const*>(params.weight),
+            reinterpret_cast<OutputType const*>(params.bias),
+            reinterpret_cast<OutputType*>(params.output),
+            params.m,
+            params.n,
+            params.k,
+            params.alpha);
+  }
+}
+
+// Recursively instantiate kernels for TILE_M = 1..16 and launch the one whose
+// tile height matches the runtime M; returns false when M is out of range.
+template <typename InputType, typename OutputType, int32_t TILE_M>
+bool cudaCoreGemmTemplateCaller(GemmParams const& params) {
+  constexpr int cudaCoreGemmTemplateMaxM = 16;
+  if (params.m == TILE_M) {
+    cudaCoreGemmKernel<InputType, OutputType, TILE_M, 2, 128>(params);
+    return true;
+  }
+  if constexpr (TILE_M < cudaCoreGemmTemplateMaxM) {
+    return cudaCoreGemmTemplateCaller<InputType, OutputType, TILE_M + 1>(
+        params);
+  }
+  return false;
+}
+
+template <typename InputType, typename OutputType>
+bool cuda_core_gemm_launcher(GemmParams const& params) {
+  return cudaCoreGemmTemplateCaller<InputType, OutputType, 1>(params);
+}
+
+template bool cuda_core_gemm_launcher<__nv_fp8_e4m3, __nv_bfloat16>(
+    GemmParams const&);
+template bool cuda_core_gemm_launcher<__nv_fp8_e4m3, half>(GemmParams const&);
+template bool cuda_core_gemm_launcher<__nv_fp8_e5m2, __nv_bfloat16>(
+    GemmParams const&);
+template bool cuda_core_gemm_launcher<__nv_fp8_e5m2, half>(GemmParams const&);
\ No newline at end of file
diff --git a/csrc/gpu/fp8_gemm_with_cutlass/fp8_fp8_half_cuda_core_gemm.h b/csrc/gpu/fp8_gemm_with_cutlass/fp8_fp8_half_cuda_core_gemm.h
new file mode 100644
index 000000000000..31eab1943dac
--- /dev/null
+++ b/csrc/gpu/fp8_gemm_with_cutlass/fp8_fp8_half_cuda_core_gemm.h
@@ -0,0 +1,37 @@
+// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "fp8_common.h"  // NOLINT
+
+typedef struct {
+  void const* act;
+  void const* weight;
+  void const* bias;
+  void* output;
+  int32_t m, n, k;
+  float alpha;
+  cudaStream_t stream;
+} GemmParams;
+
+// Reads FLAGS_cuda_core_fp8_gemm once per process and caches the result, so
+// the flag must be set before the first FP8 GEMM executes.
+inline bool enable_cuda_core_fp8_gemm() {
+  static const char* enable_cuda_core_fp8_env =
+      std::getenv("FLAGS_cuda_core_fp8_gemm");
+  static const bool enable_cuda_core_fp8_gemm =
+      enable_cuda_core_fp8_env != nullptr &&
+      std::string(enable_cuda_core_fp8_env) == "1";
+  return enable_cuda_core_fp8_gemm;
+}
+
+template <typename InputType, typename OutputType>
+bool cuda_core_gemm_launcher(GemmParams const& params);
diff --git a/csrc/gpu/fp8_gemm_with_cutlass/fp8_fp8_half_gemm.cu b/csrc/gpu/fp8_gemm_with_cutlass/fp8_fp8_half_gemm.cu
index 3e42b065fbc9..ccec03803bda 100644
--- a/csrc/gpu/fp8_gemm_with_cutlass/fp8_fp8_half_gemm.cu
+++ b/csrc/gpu/fp8_gemm_with_cutlass/fp8_fp8_half_gemm.cu
@@ -16,6 +16,7 @@
 
 #include "fp8_gemm_fused/fp8_fp8_gemm_scale_bias_act.h"
 #include "fp8_common.h"  // NOLINT
+#include "fp8_fp8_half_cuda_core_gemm.h"
 
 std::vector<paddle::Tensor> cutlass_fp8_fp8_half_gemm(
     const paddle::Tensor& x,
@@ -116,26 +117,55 @@ std::vector<paddle::Tensor> cutlass_fp8_fp8_half_gemm(
     }
   }
 
-  GemmEpilogueAllParams params = {
-      x_ptr,
-      y_ptr,
-      out_ptr,
-      scale,
-      M,
-      N,
-      K,
-      lda,
-      ldb,
-      ldd,
-      batch_count,
-      place,
-      stream,
-      sm_version,
-      0.01,  // for leaky_relu
-      bias_data,
-      bias_dims,
-      fuse_gemm_config};
-  fp8_fp8_gemm_scale_bias_act(params);
+  // Decode-time shapes (M <= 4, row-major x, column-major y, no activation)
+  // go through the CUDA-core kernel when the flag is enabled; everything
+  // else stays on the CUTLASS tensor-core path.
+  if (M <= 4 && trans_y && !trans_x && act == "noact" &&
+      enable_cuda_core_fp8_gemm()) {
+    GemmParams params = {
+        x_ptr,
+        y_ptr,
+        bias_data,
+        out_ptr,
+        M,
+        N,
+        K,
+        scale,
+        stream,
+    };
+
+    if (x.dtype() == phi::DataType::FLOAT8_E4M3FN) {
+      if (output_dtype == "bfloat16") {
+        cuda_core_gemm_launcher<__nv_fp8_e4m3, __nv_bfloat16>(params);
+      } else {
+        cuda_core_gemm_launcher<__nv_fp8_e4m3, half>(params);
+      }
+    } else {
+      if (output_dtype == "bfloat16") {
+        cuda_core_gemm_launcher<__nv_fp8_e5m2, __nv_bfloat16>(params);
+      } else {
+        cuda_core_gemm_launcher<__nv_fp8_e5m2, half>(params);
+      }
+    }
+  } else {
+    GemmEpilogueAllParams params = {x_ptr,
+                                    y_ptr,
+                                    out_ptr,
+                                    scale,
+                                    M,
+                                    N,
+                                    K,
+                                    lda,
+                                    ldb,
+                                    ldd,
+                                    batch_count,
+                                    place,
+                                    stream,
+                                    sm_version,
+                                    0.01,  // for leaky_relu
+                                    bias_data,
+                                    bias_dims,
+                                    fuse_gemm_config};
+    fp8_fp8_gemm_scale_bias_act(params);
+  }
   return {out};
 }
diff --git a/csrc/setup_cuda.py b/csrc/setup_cuda.py
index 3978f8c17de8..58acdb3d25d3 100644
--- a/csrc/setup_cuda.py
+++ b/csrc/setup_cuda.py
@@ -159,6 +159,7 @@ def get_gencode_flags():
     sources += find_end_files("gpu/cutlass_kernels/fp8_gemm_fused/autogen", ".cu")
     sources += [
         "gpu/fp8_gemm_with_cutlass/fp8_fp8_half_gemm.cu",
+        "gpu/fp8_gemm_with_cutlass/fp8_fp8_half_cuda_core_gemm.cu",
         "gpu/fp8_gemm_with_cutlass/fp8_fp8_fp8_dual_gemm.cu",
     ]
 
diff --git a/llm/docs/predict/best_practices.md b/llm/docs/predict/best_practices.md
index 0d469edd8a82..450a2f59fa53 100644
--- a/llm/docs/predict/best_practices.md
+++ b/llm/docs/predict/best_practices.md
@@ -9,6 +9,10 @@ PaddleNLP provides a variety of environment variables for optimizing inference performance and resource usage.
 
 - `FLAGS_cublaslt_device_best_config`: with FLAGS_enable_blaslt_global_search set to 1, use `FLAGS_cublaslt_device_best_config` to specify an offline-tuned int8 gemm configuration file; the default value is "". The configuration file can be produced with `PaddleNLP/csrc/utils/tune_cublaslt_int8_gemm.py`, which automatically searches for the best gemm configuration cuBLASLt offers for the current input sizes and records the results; note that each CUDA version must be tuned separately. Using this flag when serving an A8W8 model with FLAGS_enable_blaslt_global_search set to 1 yields better performance.
 
+- `FLAGS_cuda_core_int8_gemm`: whether to enable the small-batch Int8 Gemm optimization; disabled by default. Set it to 1 to enable. When serving A8W8 models, average performance improves by roughly 40%-55%; applicable to GPUs with SM >= 70.
+
+- `FLAGS_cuda_core_fp8_gemm`: whether to enable the small-batch FP8 Gemm optimization; disabled by default. Set it to 1 to enable. When serving FP8 models, average performance improves by roughly 30%; applicable to GPUs with SM >= 89.
+
 **GQA optimization**
 
 - `FLAGS_use_xqa_optim`: whether GQA enables the XQA optimization; the default value is 0 (disabled). GQA models (e.g. llama3/3.1, qwen2) perform better with this set to 1.
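
Usage note (not part of the patch): enable_cuda_core_fp8_gemm() reads
FLAGS_cuda_core_fp8_gemm once via std::getenv and caches the result, so the
flag must be in the environment before the first FP8 GEMM runs. A minimal
sketch in Python, assuming the usual PaddleNLP inference entry point (the
predictor path below is illustrative):

    import os

    # Set before any Paddle custom op runs; the C++ gate caches the value on
    # first call, so exporting it later in the process has no effect.
    os.environ["FLAGS_cuda_core_fp8_gemm"] = "1"
    os.environ["FLAGS_cuda_core_int8_gemm"] = "1"  # Int8 counterpart, SM >= 70

    # ... then launch inference as usual, e.g. (illustrative path):
    #   python llm/predict/predictor.py --model_name_or_path <fp8_model> ...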