Commit

use fp8 cuda core gemm kernel when M<=4
zhink committed Nov 26, 2024
1 parent 85333aa commit 2f5e5ea
Showing 5 changed files with 283 additions and 20 deletions.
191 changes: 191 additions & 0 deletions csrc/gpu/fp8_gemm_with_cutlass/fp8_fp8_half_cuda_core_gemm.cu
@@ -0,0 +1,191 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "fp8_fp8_half_cuda_core_gemm.h"
#include "cutlass/numeric_conversion.h"

template <typename InputType,
          typename OutputType,
          int32_t TILE_M,
          int32_t TILE_N,
          int32_t BLOCK_SIZE,
          bool UseBias>
__global__ void cudaCoreGemm(InputType const* __restrict__ act,
                             InputType const* __restrict__ weight,
                             OutputType const* __restrict__ bias,
                             OutputType* __restrict__ output,
                             int32_t m,
                             int32_t n,
                             int32_t k,
                             float alpha) {
  using VecType = int4;
  static constexpr int32_t kStepK =
      static_cast<int32_t>(128 / (8 * sizeof(InputType)));
  static constexpr int32_t kTileK = kStepK * BLOCK_SIZE;
  auto tileIdM = static_cast<int32_t>(blockIdx.x * TILE_M);
  auto tileIdN = static_cast<int32_t>(blockIdx.y * TILE_N);
  auto tid = static_cast<int32_t>(threadIdx.x);
  float tile_a[kStepK], tile_w[TILE_N * kStepK];
  float acc[TILE_M * TILE_N];

  static_assert(kStepK % 4 == 0);
  using Converter = cutlass::NumericArrayConverter<float, InputType, 4>;
  using CvtSrcType = typename Converter::source_type;
  using CvtResType = typename Converter::result_type;

  static constexpr int32_t kCvtCount =
      static_cast<int32_t>(sizeof(VecType) / sizeof(CvtSrcType));

#pragma unroll
  for (int32_t i = 0; i < TILE_M * TILE_N; ++i) {
    acc[i] = 0;
  }
  act += tileIdM * k;
  weight += tileIdN * k;
  output += tileIdM * n + tileIdN;
  if constexpr (UseBias) {
    bias += tileIdN;
  }
  for (int32_t idxK = tid * kStepK; idxK < k; idxK += kTileK) {
    for (int32_t i = 0; i < TILE_N; ++i) {
      auto tile_w_quantized =
          reinterpret_cast<VecType const*>(weight + i * k + idxK)[0];
#pragma unroll
      for (int32_t cvtIdx = 0; cvtIdx < kCvtCount; ++cvtIdx) {
        reinterpret_cast<CvtResType*>(tile_w)[i * kCvtCount + cvtIdx] =
            Converter::convert(
                reinterpret_cast<CvtSrcType*>(&tile_w_quantized)[cvtIdx]);
      }
    }
#pragma unroll
    for (int32_t i = 0; i < TILE_M; ++i) {
      auto tile_a_quantized =
          reinterpret_cast<VecType const*>(act + i * k + idxK)[0];
#pragma unroll
      for (int32_t cvtIdx = 0; cvtIdx < kCvtCount; ++cvtIdx) {
        reinterpret_cast<CvtResType*>(tile_a)[cvtIdx] = Converter::convert(
            reinterpret_cast<CvtSrcType*>(&tile_a_quantized)[cvtIdx]);
      }
#pragma unroll
      for (int32_t j = 0; j < TILE_N; ++j) {
#pragma unroll
        for (int32_t l = 0; l < kStepK; ++l) {
          acc[i * TILE_N + j] =
              fma(tile_a[l], tile_w[j * kStepK + l], acc[i * TILE_N + j]);
        }
      }
    }
  }

  using WarpReduce = cub::WarpReduce<float>;

  static constexpr int32_t kWarpSize = 32;
  static constexpr int32_t kWarpNum = BLOCK_SIZE / kWarpSize;
  int32_t warpId = tid / kWarpSize, laneId = tid % kWarpSize;
  __shared__ float shmem[TILE_M * TILE_N * kWarpNum];
  __shared__ typename WarpReduce::TempStorage tempStorage[kWarpNum];
#pragma unroll
  for (int32_t mi = 0; mi < TILE_M; ++mi) {
#pragma unroll
    for (int32_t ni = 0; ni < TILE_N; ++ni) {
      float val = WarpReduce(tempStorage[warpId]).Sum(acc[mi * TILE_N + ni]);
      if (laneId == 0) {
        shmem[mi * TILE_N + ni + warpId * TILE_M * TILE_N] = val;
      }
    }
  }

  __syncthreads();
  for (int32_t ii = tid; ii < TILE_M * TILE_N; ii += BLOCK_SIZE) {
    int32_t mid = ii / TILE_N, nid = ii % TILE_N;
    float val = 0;
#pragma unroll
    for (int32_t jj = 0; jj < kWarpNum; ++jj) {
      val += shmem[jj * TILE_M * TILE_N + ii];
    }

    if constexpr (UseBias) {
      output[mid * n + nid] =
          static_cast<OutputType>(val * alpha + static_cast<float>(bias[nid]));
    } else {
      output[mid * n + nid] = static_cast<OutputType>(val * alpha);
    }
  }
}

template <typename InputType,
          typename OutputType,
          int32_t TILE_M,
          int32_t TILE_N,
          int32_t BLOCK_SIZE>
void cudaCoreGemmKernel(GemmParams const& params) {
  dim3 block(BLOCK_SIZE);
  dim3 grid(params.m / TILE_M, params.n / TILE_N);

  if (params.bias != nullptr) {
    cudaCoreGemm<InputType, OutputType, TILE_M, TILE_N, BLOCK_SIZE, true>
        <<<grid, block, 0, params.stream>>>(
            reinterpret_cast<InputType const*>(params.act),
            reinterpret_cast<InputType const*>(params.weight),
            reinterpret_cast<OutputType const*>(params.bias),
            reinterpret_cast<OutputType*>(params.output),
            params.m,
            params.n,
            params.k,
            params.alpha);
  } else {
    cudaCoreGemm<InputType, OutputType, TILE_M, TILE_N, BLOCK_SIZE, false>
        <<<grid, block, 0, params.stream>>>(
            reinterpret_cast<InputType const*>(params.act),
            reinterpret_cast<InputType const*>(params.weight),
            reinterpret_cast<OutputType const*>(params.bias),
            reinterpret_cast<OutputType*>(params.output),
            params.m,
            params.n,
            params.k,
            params.alpha);
  }
}

template <typename InputType,
          typename OutputType,
          int TILE_M,
          int TILE_N,
          int BLOCK_SIZE>
bool cudaCoreGemmTemplateCaller(GemmParams const& params) {
  constexpr int cudaCoreGemmTemplateMaxM = 16;
  if (params.m == TILE_M) {
    cudaCoreGemmKernel<InputType, OutputType, TILE_M, TILE_N, BLOCK_SIZE>(
        params);
    return true;
  }
  if constexpr (TILE_M < cudaCoreGemmTemplateMaxM) {
    return cudaCoreGemmTemplateCaller<InputType,
                                      OutputType,
                                      TILE_M + 1,
                                      TILE_N,
                                      BLOCK_SIZE>(params);
  }
  return false;
}

template <typename InputType, typename OutputType>
bool cuda_core_gemm_launcher(GemmParams const& params) {
  return cudaCoreGemmTemplateCaller<InputType, OutputType, 1, 2, 256>(params);
}

template bool cuda_core_gemm_launcher<__nv_fp8_e4m3, __nv_bfloat16>(GemmParams const&);
template bool cuda_core_gemm_launcher<__nv_fp8_e4m3, half>(GemmParams const&);
template bool cuda_core_gemm_launcher<__nv_fp8_e5m2, __nv_bfloat16>(GemmParams const&);
template bool cuda_core_gemm_launcher<__nv_fp8_e5m2, half>(GemmParams const&);
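
Inside the kernel, each thread accumulates partial dot products in registers, and the per-tile results are then combined with `cub::WarpReduce` plus a shared-memory pass across warps. The stand-alone sketch below is illustrative only (it is not part of this commit; the file and kernel names are made up) and isolates just the warp-reduction step:

// warp_sum_demo.cu -- illustrative only, not part of this commit.
#include <cub/cub.cuh>
#include <cstdio>

__global__ void warp_sum_demo(const float* in, float* out) {
  // One warp: each lane contributes one partial value, just as each thread in
  // cudaCoreGemm contributes its partial dot product acc[mi * TILE_N + ni].
  using WarpReduce = cub::WarpReduce<float>;
  __shared__ typename WarpReduce::TempStorage temp_storage;
  float partial = in[threadIdx.x];
  float total = WarpReduce(temp_storage).Sum(partial);  // valid in lane 0
  if (threadIdx.x == 0) {
    *out = total;  // lane 0 publishes the warp-wide sum
  }
}

int main() {
  float h_in[32], h_out = 0.f;
  for (int i = 0; i < 32; ++i) h_in[i] = 1.f;  // expected sum: 32
  float *d_in, *d_out;
  cudaMalloc(&d_in, sizeof(h_in));
  cudaMalloc(&d_out, sizeof(float));
  cudaMemcpy(d_in, h_in, sizeof(h_in), cudaMemcpyHostToDevice);
  warp_sum_demo<<<1, 32>>>(d_in, d_out);
  cudaMemcpy(&h_out, d_out, sizeof(float), cudaMemcpyDeviceToHost);
  printf("warp sum = %f\n", h_out);
  cudaFree(d_in);
  cudaFree(d_out);
  return 0;
}

In `cudaCoreGemm` the same `Sum` call runs once per `(mi, ni)` accumulator, lane 0 of each warp writes its result into `shmem`, and a final loop adds the per-warp partials before scaling by `alpha`.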
37 changes: 37 additions & 0 deletions csrc/gpu/fp8_gemm_with_cutlass/fp8_fp8_half_cuda_core_gemm.h
@@ -0,0 +1,37 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "fp8_common.h" // NOLINT

typedef struct {
  void const* act;
  void const* weight;
  void const* bias;
  void* output;
  int32_t m, n, k;
  float alpha;
  cudaStream_t stream;
} GemmParams;

inline bool enable_cuda_core_fp8_gemm() {
  static const char* enable_cuda_core_fp8_env =
      std::getenv("FLAGS_cuda_core_fp8_gemm");
  static const bool enable_cuda_core_fp8_gemm =
      enable_cuda_core_fp8_env != nullptr &&
      std::string(enable_cuda_core_fp8_env) == "1";
  return enable_cuda_core_fp8_gemm;
}

template <typename InputType, typename OutputType>
bool cuda_core_gemm_launcher(GemmParams const& params);
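
A minimal host-side sketch of how this header is intended to be consumed, mirroring the dispatch added to `fp8_fp8_half_gemm.cu` below; the wrapper function and buffer names are hypothetical, and the device buffers are assumed to be allocated elsewhere:

// usage_sketch.cu -- illustrative only, not part of this commit.
#include <cuda_fp16.h>
#include <cuda_fp8.h>

#include "fp8_fp8_half_cuda_core_gemm.h"

// Row-major activation A (m x k), row-major transposed weight B^T (n x k),
// half-precision output C (m x n); the accumulated value is scaled by `scale`.
void run_small_m_fp8_gemm(const void* d_act, const void* d_weight, void* d_out,
                          int32_t m, int32_t n, int32_t k, float scale,
                          cudaStream_t stream) {
  if (m <= 4 && enable_cuda_core_fp8_gemm()) {  // gated by FLAGS_cuda_core_fp8_gemm=1
    GemmParams params = {d_act, d_weight, /*bias=*/nullptr, d_out,
                         m, n, k, scale, stream};
    cuda_core_gemm_launcher<__nv_fp8_e4m3, half>(params);
  }
  // Larger M (or fused activations) stays on the cutlass epilogue path.
}

The launcher only covers M in [1, 16] (see `cudaCoreGemmTemplateMaxM`), and the commit itself takes this path only for M <= 4 with no fused activation.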
70 changes: 50 additions & 20 deletions csrc/gpu/fp8_gemm_with_cutlass/fp8_fp8_half_gemm.cu
@@ -16,6 +16,7 @@

#include "fp8_gemm_fused/fp8_fp8_gemm_scale_bias_act.h"
#include "fp8_common.h" // NOLINT
#include "fp8_fp8_half_cuda_core_gemm.h"

std::vector<paddle::Tensor> cutlass_fp8_fp8_half_gemm(
    const paddle::Tensor& x,
@@ -116,26 +117,55 @@ std::vector<paddle::Tensor> cutlass_fp8_fp8_half_gemm(
}
}

  GemmEpilogueAllParams params = {
      x_ptr,
      y_ptr,
      out_ptr,
      scale,
      M,
      N,
      K,
      lda,
      ldb,
      ldd,
      batch_count,
      place,
      stream,
      sm_version,
      0.01,  // for leaky_relu
      bias_data,
      bias_dims,
      fuse_gemm_config};
  fp8_fp8_gemm_scale_bias_act(params);
  if (M <= 4 && trans_y && !trans_x && act == "noact" && enable_cuda_core_fp8_gemm()) {
    GemmParams params = {
        x_ptr,
        y_ptr,
        bias_data,
        out_ptr,
        M,
        N,
        K,
        scale,
        stream,
    };

    if (x.dtype() == phi::DataType::FLOAT8_E4M3FN) {
      if (output_dtype == "bfloat16") {
        cuda_core_gemm_launcher<__nv_fp8_e4m3, __nv_bfloat16>(params);
      } else {
        cuda_core_gemm_launcher<__nv_fp8_e4m3, half>(params);
      }
    } else {
      if (output_dtype == "bfloat16") {
        cuda_core_gemm_launcher<__nv_fp8_e5m2, __nv_bfloat16>(params);
      } else {
        cuda_core_gemm_launcher<__nv_fp8_e5m2, half>(params);
      }
    }
  } else {
    GemmEpilogueAllParams params = {x_ptr,
                                    y_ptr,
                                    out_ptr,
                                    scale,
                                    M,
                                    N,
                                    K,
                                    lda,
                                    ldb,
                                    ldd,
                                    batch_count,
                                    place,
                                    stream,
                                    sm_version,
                                    0.01,  // for leaky_relu
                                    bias_data,
                                    bias_dims,
                                    fuse_gemm_config};
    fp8_fp8_gemm_scale_bias_act(params);
  }
  return {out};
}

1 change: 1 addition & 0 deletions csrc/setup_cuda.py
@@ -159,6 +159,7 @@ def get_gencode_flags():
sources += find_end_files("gpu/cutlass_kernels/fp8_gemm_fused/autogen", ".cu")
sources += [
    "gpu/fp8_gemm_with_cutlass/fp8_fp8_half_gemm.cu",
    "gpu/fp8_gemm_with_cutlass/fp8_fp8_half_cuda_core_gemm.cu",
    "gpu/fp8_gemm_with_cutlass/fp8_fp8_fp8_dual_gemm.cu",
]

4 changes: 4 additions & 0 deletions llm/docs/predict/best_practices.md
@@ -9,6 +9,10 @@ PaddleNLP provides a variety of environment variables for optimizing inference performance and resource usage

- `FLAGS_cublaslt_device_best_config`: when `FLAGS_enable_blaslt_global_search` is set to 1, use `FLAGS_cublaslt_device_best_config` to point to an offline-tuned int8 GEMM configuration file; the default value is "". The configuration file can be produced with `PaddleNLP/csrc/utils/tune_cublaslt_int8_gemm.py`, which automatically searches for the best GEMM configuration offered by cuBLASLt for the current input sizes and records the results. Note that different CUDA versions need to be tuned separately. When running A8W8 models with `FLAGS_enable_blaslt_global_search` set to 1, using this flag yields better performance.

- `FLAGS_cuda_core_int8_gemm`: whether to enable the small-batch int8 GEMM optimization; disabled by default. Set it to 1 to enable it. When running A8W8 model inference, average performance improves by roughly 40%-55%. Applies to GPUs with SM >= 70.

- `FLAGS_cuda_core_fp8_gemm`: whether to enable the small-batch FP8 GEMM optimization; disabled by default. Set it to 1 to enable it. When running FP8 model inference, average performance improves by roughly 30%. Applies to GPUs with SM >= 89 (a minimal usage sketch follows).
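
For reference, a minimal C++ sketch of enabling the flag programmatically; it is illustrative only and assumes a POSIX environment (exporting `FLAGS_cuda_core_fp8_gemm=1` in the shell is equivalent):

#include <cstdlib>

int main() {
  // The flag is read once and cached (static locals in enable_cuda_core_fp8_gemm),
  // so it must be set before the first FP8 GEMM is dispatched.
  setenv("FLAGS_cuda_core_fp8_gemm", "1", /*overwrite=*/1);
  // ... load the FP8 model and run inference as usual ...
  return 0;
}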

**GQA optimization**

- `FLAGS_use_xqa_optim`: whether to enable the XQA optimization for GQA; the default value is 0 (disabled). GQA models (e.g. llama3/3.1, qwen2) perform better with it set to 1.
