Skip to content

support per_tensor_quant and per_token_group_quant #10359

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
201 changes: 0 additions & 201 deletions csrc/gpu/group_quant.cu

This file was deleted.

39 changes: 39 additions & 0 deletions csrc/gpu/helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,13 @@ class PDTraits<paddle::DataType::FLOAT8_E4M3FN> {
typedef paddle::float8_e4m3fn data_t;
};

template <>
class PDTraits<paddle::DataType::INT8> {
public:
typedef int8_t DataType;
typedef int8_t data_t;
};

template <typename T, int Size>
struct alignas(sizeof(T) * Size) AlignedVector {
T val[Size];
Expand Down Expand Up @@ -245,3 +252,35 @@ inline bool GetMlaUseTensorcore() {
const bool mla_use_tensorcore = flags_mla_use_tensorcore && enable_mla_tensorcore;
return mla_use_tensorcore;
}

__device__ __forceinline__ float atomicMaxFloat(float* addr, float value) {
float old;
old = (value >= 0) ? __int_as_float(atomicMax((int*)addr, __float_as_int(value)))
: __uint_as_float(atomicMin((unsigned int*)addr, __float_as_uint(value)));
return old;
}

__device__ __forceinline__ float warpReduceMax(float max_value) {
max_value = fmaxf(max_value, __shfl_xor_sync(0xffffffff, max_value, 16));
max_value = fmaxf(max_value, __shfl_xor_sync(0xffffffff, max_value, 8));
max_value = fmaxf(max_value, __shfl_xor_sync(0xffffffff, max_value, 4));
max_value = fmaxf(max_value, __shfl_xor_sync(0xffffffff, max_value, 2));
max_value = fmaxf(max_value, __shfl_xor_sync(0xffffffff, max_value, 1));
return max_value;
}

__device__ __forceinline__ float blockReduceMax(float max_value) {
static __shared__ float warpLevelMaxs[32];
const int laneId = threadIdx.x & 0x1f;;
const int warpId = threadIdx.x >> 5;

max_value = warpReduceMax(max_value);

if (laneId == 0) warpLevelMaxs[warpId] = max_value;
__syncthreads();

max_value = (threadIdx.x < blockDim.x / 32) ? warpLevelMaxs[laneId] : 0;
if (warpId == 0) max_value = warpReduceMax(max_value);

return max_value;
}
Loading
Loading