[Inference] use fp8 cuda core gemm kernel when M<=4 #9423

Merged 1 commit on Nov 26, 2024
191 changes: 191 additions & 0 deletions csrc/gpu/fp8_gemm_with_cutlass/fp8_fp8_half_cuda_core_gemm.cu
@@ -0,0 +1,191 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "fp8_fp8_half_cuda_core_gemm.h"
#include "cutlass/numeric_conversion.h"
#include <cub/cub.cuh>  // explicit include for cub::WarpReduce used below

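// cudaCoreGemm computes one TILE_M x TILE_N tile of output = alpha * act * weight^T (+ bias)
// per thread block: act is [m, k] and weight is [n, k], both FP8. Each thread strides over K
// with 128-bit vector loads, converts FP8 to float with cutlass::NumericArrayConverter,
// accumulates with fma, and the block then reduces the partial sums before writing the tile.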
template <typename InputType,
typename OutputType,
int32_t TILE_M,
int32_t TILE_N,
int32_t BLOCK_SIZE,
bool UseBias>
__global__ void cudaCoreGemm(InputType const* __restrict__ act,
InputType const* __restrict__ weight,
OutputType const* __restrict__ bias,
OutputType* __restrict__ output,
int32_t m,
int32_t n,
int32_t k,
float alpha) {
using VecType = int4;
static constexpr int32_t kStepK =
static_cast<int32_t>(128 / (8 * sizeof(InputType)));
static constexpr int32_t kTileK = kStepK * BLOCK_SIZE;
auto tileIdM = static_cast<int32_t>(blockIdx.x * TILE_M);
auto tileIdN = static_cast<int32_t>(blockIdx.y * TILE_N);
auto tid = static_cast<int32_t>(threadIdx.x);
float tile_a[kStepK], tile_w[TILE_N * kStepK];
float acc[TILE_M * TILE_N];

static_assert(kStepK % 4 == 0);
using Converter = cutlass::NumericArrayConverter<float, InputType, 4>;
using CvtSrcType = typename Converter::source_type;
using CvtResType = typename Converter::result_type;

static constexpr int32_t kCvtCount =
static_cast<int32_t>(sizeof(VecType) / sizeof(CvtSrcType));

#pragma unroll
for (int32_t i = 0; i < TILE_M * TILE_N; ++i) {
acc[i] = 0;
}
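// Offset the global pointers to the tile this block is responsible for.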
act += tileIdM * k;
weight += tileIdN * k;
output += tileIdM * n + tileIdN;
if constexpr (UseBias) {
bias += tileIdN;
}
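// Main K loop: each thread starts at element tid * kStepK and advances by kTileK
// (= BLOCK_SIZE * kStepK) per iteration. Every iteration loads one 128-bit vector of
// weight data for each of the TILE_N columns and one 128-bit vector of activations for
// each of the TILE_M rows, converts FP8 to float, and accumulates the TILE_M * TILE_N
// partial dot products in registers.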
for (int32_t idxK = tid * kStepK; idxK < k; idxK += kTileK) {
for (int32_t i = 0; i < TILE_N; ++i) {
auto tile_w_quantized =
reinterpret_cast<VecType const*>(weight + i * k + idxK)[0];
#pragma unroll
for (int32_t cvtIdx = 0; cvtIdx < kCvtCount; ++cvtIdx) {
reinterpret_cast<CvtResType*>(tile_w)[i * kCvtCount + cvtIdx] =
Converter::convert(
reinterpret_cast<CvtSrcType*>(&tile_w_quantized)[cvtIdx]);
}
}
#pragma unroll
for (int32_t i = 0; i < TILE_M; ++i) {
auto tile_a_quantized =
reinterpret_cast<VecType const*>(act + i * k + idxK)[0];
#pragma unroll
for (int32_t cvtIdx = 0; cvtIdx < kCvtCount; ++cvtIdx) {
reinterpret_cast<CvtResType*>(tile_a)[cvtIdx] = Converter::convert(
reinterpret_cast<CvtSrcType*>(&tile_a_quantized)[cvtIdx]);
}
#pragma unroll
for (int32_t j = 0; j < TILE_N; ++j) {
#pragma unroll
for (int32_t l = 0; l < kStepK; ++l) {
acc[i * TILE_N + j] =
fma(tile_a[l], tile_w[j * kStepK + l], acc[i * TILE_N + j]);
}
}
}
}

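// Each thread now holds partial sums for the whole tile. Reduce them within each warp with
// cub::WarpReduce, stage the per-warp results in shared memory, then sum across warps,
// apply alpha (and bias if present), and write the final TILE_M x TILE_N tile.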
typedef cub::WarpReduce<float> WarpReduce;

static constexpr int32_t kWarpSize = 32;
static constexpr int32_t kWarpNum = BLOCK_SIZE / kWarpSize;
int32_t warpId = tid / kWarpSize, laneId = tid % kWarpSize;
__shared__ float shmem[TILE_M * TILE_N * kWarpNum];
__shared__ typename WarpReduce::TempStorage tempStorage[kWarpNum];
#pragma unroll
for (int32_t mi = 0; mi < TILE_M; ++mi) {
#pragma unroll
for (int32_t ni = 0; ni < TILE_N; ++ni) {
float val = WarpReduce(tempStorage[warpId]).Sum(acc[mi * TILE_N + ni]);
if (laneId == 0) {
shmem[mi * TILE_N + ni + warpId * TILE_M * TILE_N] = val;
}
}
}

__syncthreads();
for (int32_t ii = tid; ii < TILE_M * TILE_N; ii += BLOCK_SIZE) {
int32_t mid = ii / TILE_N, nid = ii % TILE_N;
float val = 0;
#pragma unroll
for (int32_t jj = 0; jj < kWarpNum; ++jj) {
val += shmem[jj * TILE_M * TILE_N + ii];
}

if constexpr (UseBias) {
output[mid * n + nid] = static_cast<OutputType>(val * alpha + static_cast<float>(bias[nid]));
} else {
output[mid * n + nid] = static_cast<OutputType>(val * alpha);
}
}
}

template <typename InputType,
typename OutputType,
int32_t TILE_M,
int32_t TILE_N,
int32_t BLOCK_SIZE>
void cudaCoreGemmKernel(GemmParams const& params) {
dim3 block(BLOCK_SIZE);
// One thread block per TILE_M x TILE_N output tile; n is expected to be a multiple of TILE_N.
dim3 grid(params.m / TILE_M, params.n / TILE_N);

if (params.bias != nullptr) {
cudaCoreGemm<InputType, OutputType, TILE_M, TILE_N, BLOCK_SIZE, true>
<<<grid, block, 0, params.stream>>>(
reinterpret_cast<InputType const*>(params.act),
reinterpret_cast<InputType const*>(params.weight),
reinterpret_cast<OutputType const*>(params.bias),
reinterpret_cast<OutputType*>(params.output),
params.m,
params.n,
params.k,
params.alpha);
} else {
cudaCoreGemm<InputType, OutputType, TILE_M, TILE_N, BLOCK_SIZE, false>
<<<grid, block, 0, params.stream>>>(
reinterpret_cast<InputType const*>(params.act),
reinterpret_cast<InputType const*>(params.weight),
reinterpret_cast<OutputType const*>(params.bias),
reinterpret_cast<OutputType*>(params.output),
params.m,
params.n,
params.k,
params.alpha);
}
}

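// Compile-time dispatch on M: cudaCoreGemmTemplateCaller compares params.m against TILE_M and
// recurses with TILE_M + 1 up to cudaCoreGemmTemplateMaxM (16). It returns false when no
// instantiation matches the requested m.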
template <typename InputType,
typename OutputType,
int TILE_M,
int TILE_N,
int BLOCK_SIZE>
bool cudaCoreGemmTemplateCaller(GemmParams const& params) {
constexpr int cudaCoreGemmTemplateMaxM = 16;
if (params.m == TILE_M) {
cudaCoreGemmKernel<InputType, OutputType, TILE_M, TILE_N, BLOCK_SIZE>(
params);
return true;
}
if constexpr (TILE_M < cudaCoreGemmTemplateMaxM) {
return cudaCoreGemmTemplateCaller<InputType,
OutputType,
TILE_M + 1,
TILE_N,
BLOCK_SIZE>(params);
}
return false;
}

template <typename InputType, typename OutputType>
bool cuda_core_gemm_launcher(GemmParams const& params) {
return cudaCoreGemmTemplateCaller<InputType, OutputType, 1, 2, 256>(params);
}

template bool cuda_core_gemm_launcher<__nv_fp8_e4m3, __nv_bfloat16>(GemmParams const&);
template bool cuda_core_gemm_launcher<__nv_fp8_e4m3, half>(GemmParams const&);
template bool cuda_core_gemm_launcher<__nv_fp8_e5m2, __nv_bfloat16>(GemmParams const&);
template bool cuda_core_gemm_launcher<__nv_fp8_e5m2, half>(GemmParams const&);
37 changes: 37 additions & 0 deletions csrc/gpu/fp8_gemm_with_cutlass/fp8_fp8_half_cuda_core_gemm.h
@@ -0,0 +1,37 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <cstdlib>  // std::getenv
#include <string>

#include "fp8_common.h" // NOLINT

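// Argument bundle for the CUDA-core FP8 GEMM: act is [m, k], weight is [n, k] (used as B^T),
// bias is [n] or nullptr, output is [m, n]; alpha scales the accumulated product.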
struct GemmParams {
void const* act;
void const* weight;
void const* bias;
void* output;
int32_t m, n, k;
float alpha;
cudaStream_t stream;
};

inline bool enable_cuda_core_fp8_gemm() {
static const char* enable_cuda_core_fp8_env = std::getenv("FLAGS_cuda_core_fp8_gemm");
static const bool enable_cuda_core_fp8_gemm =
enable_cuda_core_fp8_env != nullptr && std::string(enable_cuda_core_fp8_env) == "1";
return enable_cuda_core_fp8_gemm;
}

template <typename InputType, typename OutputType>
bool cuda_core_gemm_launcher(GemmParams const& params);
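For orientation, here is a minimal host-side sketch of how the launcher might be driven, assuming `fp8_common.h` pulls in the CUDA FP8/FP16 types as it does for the rest of this file. The helper name and the device pointers (d_act, d_weight, d_bias, d_out) are hypothetical; the real call site is in `fp8_fp8_half_gemm.cu` below.

#include "fp8_fp8_half_cuda_core_gemm.h"

// Hypothetical helper, for illustration only. d_act/d_weight/d_bias/d_out are assumed to be
// device buffers the caller has already allocated and filled (FP8 e4m3 activations [m, k],
// FP8 e4m3 weights [n, k], half bias [n] or nullptr, half output [m, n]).
bool launch_small_m_fp8_gemm(void const* d_act, void const* d_weight, void const* d_bias,
                             void* d_out, int32_t m, int32_t n, int32_t k,
                             float scale, cudaStream_t stream) {
  // The fast path is opt-in via the FLAGS_cuda_core_fp8_gemm environment variable.
  if (!enable_cuda_core_fp8_gemm()) {
    return false;
  }
  GemmParams params = {d_act, d_weight, d_bias, d_out, m, n, k, scale, stream};
  // Returns false if m falls outside the 1..16 range covered by the template dispatch.
  return cuda_core_gemm_launcher<__nv_fp8_e4m3, half>(params);
}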
70 changes: 50 additions & 20 deletions csrc/gpu/fp8_gemm_with_cutlass/fp8_fp8_half_gemm.cu
@@ -16,6 +16,7 @@

#include "fp8_gemm_fused/fp8_fp8_gemm_scale_bias_act.h"
#include "fp8_common.h" // NOLINT
#include "fp8_fp8_half_cuda_core_gemm.h"

std::vector<paddle::Tensor> cutlass_fp8_fp8_half_gemm(
const paddle::Tensor& x,
@@ -116,26 +117,55 @@ std::vector<paddle::Tensor> cutlass_fp8_fp8_half_gemm(
}
}

GemmEpilogueAllParams params = {
x_ptr,
y_ptr,
out_ptr,
scale,
M,
N,
K,
lda,
ldb,
ldd,
batch_count,
place,
stream,
sm_version,
0.01, // for leaky_relu
bias_data,
bias_dims,
fuse_gemm_config};
fp8_fp8_gemm_scale_bias_act(params);
// Small-M fast path: dispatch to the CUDA-core FP8 GEMM when M <= 4, y is transposed,
// x is not, there is no fused activation, and FLAGS_cuda_core_fp8_gemm is set to 1.
if (M <= 4 && trans_y && !trans_x && act == "noact" && enable_cuda_core_fp8_gemm()) {
GemmParams params = {
x_ptr,
y_ptr,
bias_data,
out_ptr,
M,
N,
K,
scale,
stream,
};

if (x.dtype() == phi::DataType::FLOAT8_E4M3FN) {
if (output_dtype == "bfloat16") {
cuda_core_gemm_launcher<__nv_fp8_e4m3, __nv_bfloat16>(params);
} else {
cuda_core_gemm_launcher<__nv_fp8_e4m3, half>(params);
}
} else {
if (output_dtype == "bfloat16") {
cuda_core_gemm_launcher<__nv_fp8_e5m2, __nv_bfloat16>(params);
} else {
cuda_core_gemm_launcher<__nv_fp8_e5m2, half>(params);
}
}
} else {
GemmEpilogueAllParams params = {x_ptr,
y_ptr,
out_ptr,
scale,
M,
N,
K,
lda,
ldb,
ldd,
batch_count,
place,
stream,
sm_version,
0.01, // for leaky_relu
bias_data,
bias_dims,
fuse_gemm_config};
fp8_fp8_gemm_scale_bias_act(params);
}
return {out};
}

1 change: 1 addition & 0 deletions csrc/setup_cuda.py
@@ -159,6 +159,7 @@ def get_gencode_flags():
sources += find_end_files("gpu/cutlass_kernels/fp8_gemm_fused/autogen", ".cu")
sources += [
"gpu/fp8_gemm_with_cutlass/fp8_fp8_half_gemm.cu",
"gpu/fp8_gemm_with_cutlass/fp8_fp8_half_cuda_core_gemm.cu",
"gpu/fp8_gemm_with_cutlass/fp8_fp8_fp8_dual_gemm.cu",
]

4 changes: 4 additions & 0 deletions llm/docs/predict/best_practices.md
@@ -9,6 +9,10 @@ PaddleNLP provides a number of environment variables for optimizing inference performance and resource usage.

- `FLAGS_cublaslt_device_best_config`: With FLAGS_enable_blaslt_global_search set to 1, use `FLAGS_cublaslt_device_best_config` to point to an offline-tuned int8 gemm configuration file; the default value is "". The configuration file can be produced with `PaddleNLP/csrc/utils/tune_cublaslt_int8_gemm.py`, which automatically searches for the best gemm configuration offered by cuBLASLt for the current input sizes and records the result; note that each CUDA version needs to be tuned separately. When running inference on an A8W8 model with FLAGS_enable_blaslt_global_search set to 1, using this flag yields better performance.

- `FLAGS_cuda_core_int8_gemm`: Whether to enable the small-batch Int8 GEMM optimization; disabled by default. Set it to 1 to enable it. When running inference on A8W8 models, average performance improves by roughly 40%-55%; applicable to GPUs with SM >= 70.

- `FLAGS_cuda_core_fp8_gemm`: Whether to enable the small-batch FP8 GEMM optimization; disabled by default. Set it to 1 to enable it (e.g. `export FLAGS_cuda_core_fp8_gemm=1`). When running inference on FP8 models, average performance improves by roughly 30%; applicable to GPUs with SM >= 89.

**GQA optimization**

- `FLAGS_use_xqa_optim`: Whether to enable the xqa optimization for GQA; the default value is 0 (disabled). GQA models (e.g. llama3/3.1, qwen2) perform better with it set to 1.