This repository has been archived by the owner on Aug 30, 2024. It is now read-only.

Commit 074e76d
Squashed commit of the following:
commit 96d2966
Author: Ding, Yi1 <[email protected]>
Date:   Wed Jul 10 08:26:45 2024 +0000

    s4=>i4

commit 696820f
Author: Ding, Yi1 <[email protected]>
Date:   Wed Jul 10 05:50:03 2024 +0000

    add back dtype_zero_pt check

commit 20da116
Author: Wang,Zhe <[email protected]>
Date:   Mon Jul 1 10:51:38 2024 +0800

    support bf16 activation

commit ea42a9a
Author: Zhe, Wang <[email protected]>
Date:   Tue Jun 18 09:23:19 2024 +0800

    support fp_zp quant
zhewang1-intc authored and DDEle committed Jul 11, 2024
1 parent 22c2123 commit 074e76d
Showing 5 changed files with 105 additions and 34 deletions.
6 changes: 5 additions & 1 deletion include/common/core/common_types.hpp
@@ -27,7 +27,11 @@ enum class grf_mode : uint8_t { normal = 0, double_grf = 1 };

enum class mem_layout : uint8_t { row_major = 0, col_major = 1 };

enum class quant_mode : uint8_t { I4_ASYM = 0, I4_SYM = 1 };
enum class quant_mode : uint8_t {
I4_ASYM = 0,
I4_SYM = 1,
I4_ASYM_FP_ZERO = 2
};

struct quant_info {
quant_mode quant_mode;
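The new I4_ASYM_FP_ZERO mode pairs symmetric 4-bit weights (centred on 8) with a floating-point zero point that is added after scaling, whereas I4_ASYM subtracts a packed 4-bit zero point before scaling. A minimal scalar sketch of the three dequantization rules, using fp32 math for clarity (dequant_scalar and its parameters are illustrative, not part of the library):

    #include <cassert>
    #include <cstdint>

    enum class quant_mode : uint8_t { I4_ASYM = 0, I4_SYM = 1, I4_ASYM_FP_ZERO = 2 };

    // Dequantize one unsigned 4-bit value q (0..15) according to the selected mode.
    // scale and zero_pt_fp are per-group parameters; zero_pt_i4 is a packed 4-bit zero point.
    float dequant_scalar(quant_mode mode, uint8_t q, float scale,
                         uint8_t zero_pt_i4, float zero_pt_fp) {
      switch (mode) {
        case quant_mode::I4_ASYM:         // (q - zp_i4) * scale
          return (int(q) - int(zero_pt_i4 & 0xf)) * scale;
        case quant_mode::I4_SYM:          // (q - 8) * scale
          return (int(q) - 8) * scale;
        case quant_mode::I4_ASYM_FP_ZERO: // (q - 8) * scale + zp_fp
          return (int(q) - 8) * scale + zero_pt_fp;
      }
      assert(0);
      return 0.f;
    }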
43 changes: 30 additions & 13 deletions include/experimental/group/gemm/impl/int4_dequantize_xe.hpp
@@ -102,10 +102,16 @@ class gemm_t<
std::is_same<remove_const_t<dtype_b>, remove_const_t<int4x8>>::value,
"this is for 4bit matB ");
static_assert(
std::is_same<remove_const_t<dtype_zero_pt>, remove_const_t<int4x2>>::
value ||
std::is_same<remove_const_t<dtype_zero_pt>, remove_const_t<int4x8>>::
value,
quant_info_.quant_mode == quant_mode::I4_ASYM_FP_ZERO
? std::is_same_v<
remove_const_t<dtype_zero_pt>,
remove_const_t<dtype_a>>
: (std::is_same_v<
remove_const_t<dtype_zero_pt>,
remove_const_t<int4x2>> ||
std::is_same_v<
remove_const_t<dtype_zero_pt>,
remove_const_t<int4x8>>),
"this is for 4bit zero_pt ");

/******** set memory attribute **********/
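The reworked static_assert ties the zero-point storage type to the mode: with I4_ASYM_FP_ZERO the zero point must be stored in the activation type (dtype_a, e.g. fp16 or bf16); otherwise it must be one of the packed 4-bit containers. A standalone sketch of the same rule, with stand-in types (int4x2 and int4x8 here are placeholders, not the library's definitions):

    #include <type_traits>

    struct int4x2 {};
    struct int4x8 {};
    enum class quant_mode { I4_ASYM, I4_SYM, I4_ASYM_FP_ZERO };

    // For the fp-zero-point mode the zero point must match the activation dtype;
    // otherwise it must be one of the packed 4-bit containers.
    template <quant_mode Mode, typename DtypeA, typename DtypeZeroPt>
    constexpr bool valid_zero_pt_dtype() {
      if constexpr (Mode == quant_mode::I4_ASYM_FP_ZERO)
        return std::is_same_v<DtypeZeroPt, DtypeA>;
      else
        return std::is_same_v<DtypeZeroPt, int4x2> ||
               std::is_same_v<DtypeZeroPt, int4x8>;
    }

    static_assert(valid_zero_pt_dtype<quant_mode::I4_ASYM_FP_ZERO, float, float>());
    static_assert(valid_zero_pt_dtype<quant_mode::I4_ASYM, float, int4x2>());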
@@ -284,12 +290,20 @@ class gemm_t<
arch_tag>;

// compress int4 along N dimensions
using zero_pt_tile_desc_t = subgroup::tile_desc_t<
(tile_size_x_b + pack_ratio - 1) / pack_ratio,
tile_size_y_zero_pt,
(block_size_x_b + pack_ratio - 1) / pack_ratio,
block_size_y_zero_pt,
reg_layout::tiled>;
using zero_pt_tile_desc_t = std::conditional_t<
quant_info_.quant_mode != quant_mode::I4_ASYM_FP_ZERO,
subgroup::tile_desc_t<
(tile_size_x_b + pack_ratio - 1) / pack_ratio,
tile_size_y_zero_pt,
(block_size_x_b + pack_ratio - 1) / pack_ratio,
block_size_y_zero_pt,
reg_layout::tiled>,
subgroup::tile_desc_t<
tile_size_x_b,
tile_size_y_zero_pt,
block_size_x_b,
block_size_y_zero_pt,
reg_layout::tiled>>;

using zero_pt_t = subgroup::tile_t<dtype_zero_pt, zero_pt_tile_desc_t>;
using zero_pt_payload_t = subgroup::mem_payload_t<
@@ -576,7 +590,8 @@ class gemm_t<
// TODO 1D prefetch need pack to U32/U64
subgroup::tile_prefetch<cache_hint::cached, cache_hint::cached>(
scale_prefetch_payload);
if constexpr (compute_policy::quant_mode != quant_mode::I4_SYM) {
if constexpr (
compute_policy::quant_mode != quant_mode::I4_SYM) {
// TODO 1D prefetch need pack to U32/U64
subgroup::tile_prefetch<cache_hint::cached, cache_hint::cached>(
zero_pt_prefetch_payload);
@@ -589,7 +604,8 @@ class gemm_t<
if (tile_k_idx % scale_addr_update_freq == 0) {
scale_payload.template update_tdesc<update_dir_b>(scale_t::tile_size_y);
}
if constexpr (compute_policy::quant_mode != quant_mode::I4_SYM) {
if constexpr (
compute_policy::quant_mode != quant_mode::I4_SYM) {
if (tile_k_idx % zero_pt_addr_update_freq == 0) {
zero_pt_payload.template update_tdesc<tdesc_update_dir::y_dir>(
zero_pt_t::tile_size_y);
@@ -603,7 +619,8 @@ class gemm_t<
if ((scale_prefetch_addr_i % scale_addr_update_freq) == 0) {
scale_prefetch_payload.template update_tdesc<tdesc_update_dir::y_dir>(
scale_t::tile_size_y);
if constexpr (compute_policy::quant_mode != quant_mode::I4_SYM) {
if constexpr (
compute_policy::quant_mode != quant_mode::I4_SYM) {
zero_pt_prefetch_payload
.template update_tdesc<tdesc_update_dir::y_dir>(
zero_pt_t::tile_size_y);
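Because packed int4 zero points hold pack_ratio values per storage element, the zero-point tile descriptor divides its width along N by pack_ratio (rounding up); with fp16 zero points every output column keeps its own element, so the I4_ASYM_FP_ZERO branch of the std::conditional_t keeps the unpacked width. A small sketch of that width computation (function and parameter names are illustrative):

    #include <cstdio>

    // Width in *storage elements* of the zero-point tile along N.
    // Packed int4 modes share one element between pack_ratio values (ceil divide);
    // the fp zero-point mode keeps one fp16 element per column.
    constexpr unsigned zero_pt_tile_width(unsigned tile_size_x_b, unsigned pack_ratio,
                                          bool fp_zero_pt) {
      return fp_zero_pt ? tile_size_x_b
                        : (tile_size_x_b + pack_ratio - 1) / pack_ratio;
    }

    int main() {
      // e.g. a 64-wide weight tile with int4x8 packing (pack_ratio == 8)
      std::printf("packed: %u elements\n", zero_pt_tile_width(64, 8, false)); // 8
      std::printf("fp16  : %u elements\n", zero_pt_tile_width(64, 8, true));  // 64
      return 0;
    }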
@@ -570,7 +570,7 @@ class gemm_universal_t<
// check for int4x2
implementable &=
((args.matB_ld % pack_ratio == 0) && (args.matrix_n % pack_ratio == 0));
if constexpr (gemm_t::compute_policy::quant_mode != quant_mode::I4_SYM) {
if constexpr (gemm_t::compute_policy::quant_mode == quant_mode::I4_ASYM) {
implementable &= (args.zero_pt_ld % pack_ratio == 0);
}
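can_implement now applies the zero_pt_ld alignment check only in I4_ASYM mode, since only packed 4-bit zero points require the leading dimension to be a multiple of pack_ratio; fp16 zero points carry no packing constraint. Roughly, as a standalone sketch (names are illustrative):

    // Hypothetical standalone version of the leading-dimension checks.
    bool ld_constraints_ok(unsigned matB_ld, unsigned matrix_n, unsigned zero_pt_ld,
                           unsigned pack_ratio, bool is_i4_asym_packed_zp) {
      bool ok = (matB_ld % pack_ratio == 0) && (matrix_n % pack_ratio == 0);
      if (is_i4_asym_packed_zp)  // only packed int4 zero points are constrained
        ok = ok && (zero_pt_ld % pack_ratio == 0);
      return ok;
    }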

@@ -621,7 +621,10 @@ class gemm_universal_t<
int start_x_scale = start_n;
int start_y_scale = start_k / dequant_s;

int start_x_zero_pt = start_n / pack_ratio;
int start_x_zero_pt =
gemm_t::compute_policy::quant_mode == quant_mode::I4_ASYM_FP_ZERO
? start_n
: start_n / pack_ratio;
int start_y_zero_pt = start_k / dequant_s;

// set up arguments
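The workgroup's start column into the zero-point surface is now mode dependent: packed zero points are addressed in packed units (start_n / pack_ratio), while fp16 zero points are addressed per column (start_n); the row is the dequant-group index start_k / dequant_s in both cases. A compact sketch of that offset computation (illustrative names):

    struct zp_start { int x; int y; };

    // (x, y) start offsets into the zero-point surface for one workgroup.
    zp_start zero_pt_start(int start_n, int start_k, int dequant_s, int pack_ratio,
                           bool fp_zero_pt) {
      return {fp_zero_pt ? start_n : start_n / pack_ratio,  // column: unpacked vs. packed units
              start_k / dequant_s};                          // one row per dequant group
    }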
@@ -674,7 +677,8 @@ class gemm_universal_t<
inner_loop_start,
inner_loop_count,
mem_desc_scale);
} else {
} else if constexpr (
gemm_t::compute_policy::quant_mode == quant_mode::I4_ASYM) {
mem_desc_zero_pt_t mem_desc_zero_pt(
args.zero_pt_base,
{(args.matrix_n + pack_ratio - 1) / pack_ratio,
@@ -688,6 +692,23 @@ class gemm_universal_t<
inner_loop_count,
mem_desc_scale,
mem_desc_zero_pt);
} else if constexpr (
gemm_t::compute_policy::quant_mode == quant_mode::I4_ASYM_FP_ZERO) {
mem_desc_zero_pt_t mem_desc_zero_pt(
args.zero_pt_base,
{args.matrix_n,
((args.matrix_k + dequant_s - 1) / dequant_s),
args.zero_pt_ld},
{start_x_zero_pt, start_y_zero_pt});
gemm_args = gemm_args_t(
mem_desc_a,
mem_desc_b,
inner_loop_start,
inner_loop_count,
mem_desc_scale,
mem_desc_zero_pt);
} else {
assert(0);
}
matAcc_t matAcc;
matAcc.init(0);
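The argument setup now dispatches three ways on quant_mode: I4_SYM passes no zero-point descriptor, I4_ASYM describes a surface of ceil(N / pack_ratio) packed columns, and I4_ASYM_FP_ZERO describes N fp16 columns; both asymmetric cases use ceil(K / dequant_s) rows, and any other value hits assert(0). A sketch of just the shape selection (the mem_desc/gemm_args plumbing is omitted, and surface_2d is an illustrative stand-in):

    #include <cassert>

    enum class quant_mode { I4_ASYM, I4_SYM, I4_ASYM_FP_ZERO };
    struct surface_2d { unsigned width, height, pitch; };

    // Shape of the zero-point surface handed to the kernel, per quant mode.
    // zero_pt_ld is the caller-provided leading dimension (in elements).
    surface_2d zero_pt_surface(quant_mode mode, unsigned matrix_n, unsigned matrix_k,
                               unsigned dequant_s, unsigned pack_ratio,
                               unsigned zero_pt_ld) {
      unsigned groups_k = (matrix_k + dequant_s - 1) / dequant_s;
      if (mode == quant_mode::I4_ASYM)
        return {(matrix_n + pack_ratio - 1) / pack_ratio, groups_k, zero_pt_ld};
      if (mode == quant_mode::I4_ASYM_FP_ZERO)
        return {matrix_n, groups_k, zero_pt_ld};
      assert(mode == quant_mode::I4_SYM);  // symmetric mode has no zero-point surface
      return {0, 0, 0};
    }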
14 changes: 12 additions & 2 deletions include/subgroup/tile/impl/tile_op_functor.hpp
@@ -149,15 +149,25 @@ struct dequant_int4_weight_t {
cvt_blk_i8.xetla_select<step, 1>(jj * block_size_y_b + ii) =
cvt_blk_i8.xetla_select<step, 1>(jj * block_size_y_b + ii) -
zero_pt_i8;
} else if constexpr (quant_mode == quant_mode::I4_SYM) {
} else if constexpr (
quant_mode == quant_mode::I4_SYM ||
quant_mode == quant_mode::I4_ASYM_FP_ZERO) {
cvt_blk_i8.xetla_select<step, 1>(jj * block_size_y_b + ii) =
cvt_blk_i8.xetla_select<step, 1>(jj * block_size_y_b + ii) -
int8_t(8);
}
dst_blk.xetla_select<step, 1>(jj * block_size_y_b + ii) =
cvt_blk_i8.xetla_select<step, 1>(jj * block_size_y_b + ii) *
scale.reg[scale_idx];

if constexpr (quant_mode == quant_mode::I4_ASYM_FP_ZERO) {
uint32_t zero_pt_idx =
offset_y_in_tile / dequant_s * zero_pt_t::block_size_x +
offset_x_in_tile;
xetla_vector<fp16, 1> zero_pt_pack = zero_pt.reg[zero_pt_idx];
dst_blk.xetla_select<step, 1>(jj * block_size_y_b + ii) =
dst_blk.xetla_select<step, 1>(jj * block_size_y_b + ii) +
zero_pt_pack[0];
}
// sycl::ext::oneapi::experimental::printf(
// "scale[%d] %f \n",
// scale_idx,
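In the dequant functor, I4_ASYM_FP_ZERO takes the symmetric path first (subtract 8, multiply by the group scale) and then adds the fp16 zero point of the column's dequant group, indexed as offset_y_in_tile / dequant_s * block_size_x + offset_x_in_tile. A scalar reference of that computation, assuming a row-major zero-point tile with one value per (dequant group, column) pair (names illustrative):

    #include <cstdint>
    #include <vector>

    // Reference for one weight element in I4_ASYM_FP_ZERO mode.
    // Within a tile, zero points are laid out with one row per dequant group
    // along K and block_size_x columns along N.
    float dequant_fp_zero(uint8_t q4, float scale,
                          const std::vector<float>& zero_pt,
                          unsigned offset_x_in_tile, unsigned offset_y_in_tile,
                          unsigned dequant_s, unsigned block_size_x) {
      unsigned zero_pt_idx =
          offset_y_in_tile / dequant_s * block_size_x + offset_x_in_tile;
      float v = (int(q4 & 0xf) - 8) * scale;  // symmetric part: subtract 8, scale
      return v + zero_pt[zero_pt_idx];        // add the group's fp zero point
    }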
49 changes: 34 additions & 15 deletions tests/integration/gemv/int4/main.cpp
@@ -27,7 +27,7 @@ constexpr int ITER = 200;
#endif
constexpr size_t UNDEFINED_DATA_SIZE = 1024;

template <typename scalar_t>
template <typename scalar_t, quant_mode quant_mode_>
class test_col_major_1 {
public:
// Extract the parameters required by different test cases
@@ -41,7 +41,7 @@ class test_col_major_1 {
static constexpr size_t sg_k = 512 / sg_m;
static constexpr size_t dequant_s = 128;
// static constexpr quant_mode quant_mode = quant_mode::I4_ASYM;
static constexpr quant_mode quant_mode = quant_mode::I4_SYM;
static constexpr quant_mode quant_mode = quant_mode_;

static constexpr size_t local_kslicing = 1;
static constexpr size_t global_kslicing = 1;
@@ -132,13 +132,19 @@ std::vector<fp16> convert_int4(
data_type_zero_pt zero_pt) {
std::vector<fp16> dequant_fp16(sizeof(data_type_b) * 2);

int8_t zero_pt_i8 = zero_pt & 0xf;
int8_t zero_pt_i8;
if constexpr (quant_mode != quant_mode::I4_ASYM_FP_ZERO)
zero_pt_i8 = zero_pt & 0xf;
for (uint32_t i = 0; i < dequant_fp16.size(); i++) {
int8_t dequant_8bit = data_b & 0xf;
if constexpr (quant_mode == quant_mode::I4_SYM) {
dequant_fp16[i] = scale * (dequant_8bit - 8);
} else {
} else if constexpr (quant_mode == quant_mode::I4_ASYM) {
dequant_fp16[i] = scale * (dequant_8bit - zero_pt_i8);
} else if constexpr (quant_mode == quant_mode::I4_ASYM_FP_ZERO) {
dequant_fp16[i] = scale * (dequant_8bit - 8) + zero_pt;
} else {
assert(0);
}
data_b = data_b >> 4;
}
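The host reference convert_int4 walks a packed weight word low nibble first, shifting right by 4 each step, and dequantizes every nibble with the selected mode. A simplified per-byte sketch for the new mode, assuming uint8_t storage and float output (the test itself uses data_type_b words and fp16):

    #include <array>
    #include <cstdint>

    // Unpack one byte holding two 4-bit weights and dequantize them in
    // I4_ASYM_FP_ZERO mode: (nibble - 8) * scale + fp zero point.
    std::array<float, 2> convert_int4_fp_zero(uint8_t packed, float scale, float zero_pt) {
      std::array<float, 2> out{};
      for (int i = 0; i < 2; ++i) {
        uint8_t nibble = packed & 0xf;  // low nibble first
        out[i] = scale * (int(nibble) - 8) + zero_pt;
        packed >>= 4;                   // move to the next nibble
      }
      return out;
    }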
@@ -170,12 +176,14 @@ std::vector<data_type_acc_in> dequantize_weight(
for (uint32_t j = 0; j < width; j += step) {
int start_b_in = i * width + j;
int start_scale_in = start_b_in / step;
int start_zero_pt_in =
(j / step) * (matrix_n / pack_radio) + i / pack_radio;
int start_zero_pt_in = quant_mode == quant_mode::I4_ASYM_FP_ZERO
? (j / step) * matrix_n + i
: (j / step) * (matrix_n / pack_radio) + i / pack_radio;
int start_out =
layout_b == mem_layout::row_major ? 0 : i * matrix_k + j * pack_radio;
data_type_zero_pt zp_value = zero_pt[start_zero_pt_in];
zp_value = zp_value >> (4 * (i % pack_radio));
if constexpr (quant_mode != quant_mode::I4_ASYM_FP_ZERO)
zp_value = zp_value >> (4 * (i % pack_radio));
for (uint32_t jj = 0; jj < step; jj++) {
std::vector<fp16> dequant_fp16 = convert_int4<quant_mode>(
b[start_b_in + jj], scale[start_scale_in], zp_value);
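The host-side zero-point lookup also changes shape: I4_ASYM_FP_ZERO reads one fp value per (dequant group, column) from a (K / dequant_s) x N array, while the packed modes read a word from a (K / dequant_s) x (N / pack_ratio) array and shift out the column's nibble. A sketch of both lookups, assuming int4x2-style byte packing (helper names are illustrative):

    #include <cstdint>

    // Zero point for output column n_idx and dequant group k_group, fp layout:
    // one fp value per (k_group, n_idx).
    float fetch_zero_pt_fp(const float* zp, unsigned k_group, unsigned n_idx,
                           unsigned matrix_n) {
      return zp[k_group * matrix_n + n_idx];
    }

    // Packed layout: one storage word holds pack_ratio 4-bit zero points along N.
    uint8_t fetch_zero_pt_packed(const uint8_t* zp, unsigned k_group, unsigned n_idx,
                                 unsigned matrix_n, unsigned pack_ratio) {
      uint8_t word = zp[k_group * (matrix_n / pack_ratio) + n_idx / pack_ratio];
      return (word >> (4 * (n_idx % pack_ratio))) & 0xf;  // select this column's nibble
    }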
@@ -216,7 +224,10 @@ void dequantize_gemv_run(int iter) {
using data_type_a = typename Test::data_type_a;
using data_type_b = typename Test::data_type_b;
using data_type_c = typename Test::data_type_c;
using data_type_zero_pt = data_type_b;
using data_type_zero_pt = std::conditional_t<
Test::quant_mode == quant_mode::I4_ASYM_FP_ZERO,
data_type_c,
data_type_b>;
using data_type_scale = fp16;
using data_type_acc_in = fp16;
using data_type_acc = float;
@@ -226,7 +237,7 @@ void dequantize_gemv_run(int iter) {
constexpr mem_layout layout_b = Test::layout_b;

constexpr size_t size_a = matrix_m * matrix_k;
constexpr size_t size_b = matrix_k * matrix_n / (2 * sizeof(data_type_b));
constexpr size_t size_b = matrix_k * matrix_n / 2;

constexpr size_t size_scale_k = matrix_k / dequant_s;
constexpr size_t size_scale_n = matrix_n;
@@ -235,7 +246,9 @@ void dequantize_gemv_run(int iter) {
constexpr size_t size_zero_pt_k = matrix_k / dequant_s;
constexpr size_t size_zero_pt_n = matrix_n;
constexpr size_t size_zero_pt =
size_zero_pt_k * size_zero_pt_n / (2 * sizeof(data_type_b));
Test::quant_mode != quant_mode::I4_ASYM_FP_ZERO
? size_zero_pt_k * size_zero_pt_n / 2
: size_zero_pt_k * size_zero_pt_n;

constexpr size_t size_c = matrix_m * matrix_n;
constexpr size_t size_bias = matrix_n;
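Buffer sizing in the test follows the same split: packed zero points occupy half as many storage elements as there are (group, column) pairs, while fp16 zero points need one element each. A sketch mirroring that sizing, assuming two 4-bit values per packed element:

    #include <cstddef>

    // Element count of the host-side zero-point buffer, mirroring the sizing
    // in main.cpp (two 4-bit values per packed element in the packed modes).
    constexpr size_t zero_pt_elements(size_t matrix_k, size_t matrix_n,
                                      size_t dequant_s, bool fp_zero_pt) {
      size_t groups_k = matrix_k / dequant_s;       // one zero-point row per dequant group
      return fp_zero_pt ? groups_k * matrix_n       // one fp16 per (group, column)
                        : groups_k * matrix_n / 2;  // two packed 4-bit values per element
    }

    static_assert(zero_pt_elements(4096, 4096, 128, false) == 65536);
    static_assert(zero_pt_elements(4096, 4096, 128, true) == 131072);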
@@ -406,16 +419,18 @@ void dequantize_gemv_run(int iter) {
scale_h[i] = INFINITY;
}
for (unsigned i = 0; i < size_zero_pt + UNDEFINED_DATA_SIZE; ++i) {
if constexpr (std::is_same_v<int4x2, data_type_b>) {
if constexpr (std::is_same_v<int4x2, data_type_zero_pt>) {
zero_pt_h[i] = random_uint8();
#ifdef UT_DEBUG
zero_pt_h[i] = 0x12 << i;
#endif
} else if constexpr (std::is_same_v<int4x8, data_type_b>) {
} else if constexpr (std::is_same_v<int4x8, data_type_zero_pt>) {
zero_pt_h[i] = random_uint32();
#ifdef UT_DEBUG
zero_pt_h[i] = 0x33333333;
#endif
} else if constexpr (std::is_same_v<fp16, data_type_zero_pt>) {
zero_pt_h[i] = random_float();
}
}

@@ -492,7 +507,9 @@ void dequantize_gemv_run(int iter) {
Acc_d,
Cnt_d,
epilogue_args);
} else if constexpr (compute_policy::quant_mode == quant_mode::I4_ASYM) {
} else if constexpr (
compute_policy::quant_mode == quant_mode::I4_ASYM ||
compute_policy::quant_mode == quant_mode::I4_ASYM_FP_ZERO) {
gemm_arg =
typename gemm_op_t::template arguments_t<compute_policy::quant_mode>(
matrix_m,
@@ -604,8 +621,10 @@ TYPED_TEST_P(dequantize_gemv_test, esimd) {

REGISTER_TYPED_TEST_SUITE_P(dequantize_gemv_test, esimd);
using tests = ::testing::Types< //
test_col_major_1<fp16>,
test_col_major_1<bf16>,
test_col_major_1<fp16, quant_mode::I4_SYM>,
test_col_major_1<bf16, quant_mode::I4_SYM>,
test_col_major_1<fp16, quant_mode::I4_ASYM_FP_ZERO>,
test_col_major_1<bf16, quant_mode::I4_ASYM_FP_ZERO>,
// test_col_major_2,
void>;
