Skip to content
This repository has been archived by the owner on Aug 30, 2024. It is now read-only.

Commit

Permalink
ut
Browse files Browse the repository at this point in the history
  • Loading branch information
zhewang1-intc committed Jan 18, 2024
1 parent b9d6e6f commit e2eec50
Showing 1 changed file with 71 additions and 58 deletions.
129 changes: 71 additions & 58 deletions bestla/bestla/ut/bestla_prologue_b.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -177,17 +177,20 @@ class UT_TransposeBlockQuantize_F4 {
UT_TransposeBlockQuantize_F4() {
UT_START();
CheckISA(AVX512F);
ut(4096, 4096, 32, BTLA_DTYPE::F4_BNB);
ut(1024, 4096, 32, BTLA_DTYPE::F4_BNB);
ut(4096, 1024, 32, BTLA_DTYPE::F4_BNB);
ut(48, 32, 32, BTLA_DTYPE::F4_BNB);
ut(32, 32, 32, BTLA_DTYPE::F4_BNB);
ut(48, 32, 32, BTLA_DTYPE::F4_BNB);
ut(48, 32, 32, BTLA_DTYPE::F4_NF4);
ut(48, 32, 32, BTLA_DTYPE::F4_E2M1);
ut(4096, 4096, 32, BTLA_DTYPE::F4_BNB, BTLA_DTYPE::F32);
ut(1024, 4096, 32, BTLA_DTYPE::F4_BNB, BTLA_DTYPE::F32);
ut(4096, 1024, 32, BTLA_DTYPE::F4_BNB, BTLA_DTYPE::F32);
ut(48, 32, 32, BTLA_DTYPE::F4_BNB, BTLA_DTYPE::F32);
ut(32, 32, 32, BTLA_DTYPE::F4_BNB, BTLA_DTYPE::F32);
ut(48, 32, 32, BTLA_DTYPE::F4_BNB, BTLA_DTYPE::F32);
ut(48, 32, 32, BTLA_DTYPE::F4_NF4, BTLA_DTYPE::F32);
ut(48, 32, 32, BTLA_DTYPE::F4_E2M1, BTLA_DTYPE::F32);
ut(16, 15, 8, BTLA_DTYPE::F4_NF4, BTLA_DTYPE::DQ8_BNB);
ut(48, 32, 16, BTLA_DTYPE::F4_NF4, BTLA_DTYPE::DQ8_BNB);
ut(1024, 4096, 32, BTLA_DTYPE::F4_NF4, BTLA_DTYPE::DQ8_BNB);
}

void ut(int n, int k, int blocksize, BTLA_DTYPE F4_T) {
void ut(int n, int k, int blocksize, BTLA_DTYPE F4_T, BTLA_DTYPE SCA_T) {
printf("Test Case: %d %d %d\n", n, k, blocksize);
int ldb = n;
utils::aligned_vector<float> dequanRef(n * k);
Expand Down Expand Up @@ -246,18 +249,23 @@ class UT_TransposeBlockQuantize_F4 {
auto constexpr RuntimeISA = BTLA_ISA::AVX512F;
using PrologueB = prologue_b::gemm::WeightKBlockNFloat<gemm::SCoreRowNAvx512f<48, 8>, RuntimeISA>;
PrologueB kernel;
auto packedW = kernel.createStorage(n, k, blocksize, F4_T, bestla_dtype<float>);
auto packedW1 = kernel.createStorage(n, k, blocksize, F4_T, bestla_dtype<float>);
auto packedW = kernel.createStorage(n, k, blocksize, F4_T, SCA_T);
auto packedW1 = kernel.createStorage(n, k, blocksize, F4_T, SCA_T);
avector<int8_t> buf(packedW.mSize), buf1(packedW1.mSize);
packedW.assign(buf.data());
packedW1.assign(buf1.data());
kernel.packTransposeWeight(n, k, dequanRef.data(), k, &packedW, &DefaultThreading);
kernel.packQWeight(n, k, quanW.data(), ldb, scales.data(), nullptr, &packedW1, &DefaultThreading);
ut::buffer_error(packedW.SPtr<float>(), packedW1.SPtr<float>(), packedW1.CSize());
ut::buffer_error(packedW.WPtr<int8_t>(), packedW1.WPtr<int8_t>(), packedW1.mQBuf.size<int8_t>());
avector<float> dequant(n * k);
kernel.unpackTransposeWeight(n, k, &packedW1, dequant.data(), k, &DefaultThreading);
ut::buffer_error(dequanRef.data(), dequant.data(), dequant.size());
if (SCA_T != BTLA_DTYPE::DQ8_BNB) {
ut::buffer_error(packedW.SPtr<float>(), packedW1.SPtr<float>(), packedW1.CSize());
ut::buffer_error(dequanRef.data(), dequant.data(), dequant.size());
} else {
ut::buffer_error(packedW.SPtr<int8_t>(), packedW1.SPtr<int8_t>(), packedW1.CSize());
ut::buffer_error(dequanRef.data(), dequant.data(), dequant.size(), 0.1f);
}
ut::buffer_error(packedW.WPtr<int8_t>(), packedW1.WPtr<int8_t>(), packedW1.mQBuf.size<int8_t>());
}
};
#ifdef BTLA_UT_PROLOGUE_B
Expand Down Expand Up @@ -523,15 +531,15 @@ class UT_CompFp32 {

void ut_f8() {
CheckISA(AVX2);
ut<sAVX2, prologue_b::gemm::WeightKBlockNFloat, f8>(2, 4096, 4096, 32, BTLA_DTYPE::F8_E4M3);
ut<sAVX2, prologue_b::gemm::WeightKBlockNFloat, float>(2, 4096, 4096, 32, BTLA_DTYPE::F8_E4M3);
ut<sAVX2, prologue_b::gemm::WeightKBlockNFloat, f8>(2, 4096, 4096, 32, BTLA_DTYPE::F8_E5M2);
ut<sAVX2, prologue_b::gemm::WeightKBlockNFloat, float>(2, 4096, 4096, 32, BTLA_DTYPE::F8_E5M2);
ut<sAVX2, prologue_b::gemm::WeightKBlockNFloat>(2, 4096, 4096, 32, BTLA_DTYPE::F8_E4M3, BTLA_DTYPE::F8_E8M0);
ut<sAVX2, prologue_b::gemm::WeightKBlockNFloat>(2, 4096, 4096, 32, BTLA_DTYPE::F8_E4M3, BTLA_DTYPE::F32);
ut<sAVX2, prologue_b::gemm::WeightKBlockNFloat>(2, 4096, 4096, 32, BTLA_DTYPE::F8_E5M2, BTLA_DTYPE::F8_E8M0);
ut<sAVX2, prologue_b::gemm::WeightKBlockNFloat>(2, 4096, 4096, 32, BTLA_DTYPE::F8_E5M2, BTLA_DTYPE::F32);
CheckISA(AVX512F);
ut<sAVX512F, prologue_b::gemm::WeightKBlockNFloat, f8>(2, 4096, 4096, 32, BTLA_DTYPE::F8_E4M3);
ut<sAVX512F, prologue_b::gemm::WeightKBlockNFloat, float>(2, 4096, 4096, 32, BTLA_DTYPE::F8_E4M3);
ut<sAVX512F, prologue_b::gemm::WeightKBlockNFloat, f8>(2, 4096, 4096, 32, BTLA_DTYPE::F8_E5M2);
ut<sAVX512F, prologue_b::gemm::WeightKBlockNFloat, float>(2, 4096, 4096, 32, BTLA_DTYPE::F8_E5M2);
ut<sAVX512F, prologue_b::gemm::WeightKBlockNFloat>(2, 4096, 4096, 32, BTLA_DTYPE::F8_E4M3, BTLA_DTYPE::F8_E8M0);
ut<sAVX512F, prologue_b::gemm::WeightKBlockNFloat>(2, 4096, 4096, 32, BTLA_DTYPE::F8_E4M3, BTLA_DTYPE::F32);
ut<sAVX512F, prologue_b::gemm::WeightKBlockNFloat>(2, 4096, 4096, 32, BTLA_DTYPE::F8_E5M2, BTLA_DTYPE::F8_E8M0);
ut<sAVX512F, prologue_b::gemm::WeightKBlockNFloat>(2, 4096, 4096, 32, BTLA_DTYPE::F8_E5M2, BTLA_DTYPE::F32);
}
void ut_s4() {
CheckISA(AVX2);
Expand Down Expand Up @@ -589,26 +597,26 @@ class UT_CompFp32 {

void ut_f4() {
CheckISA(AVX2);
ut<sAVX2, prologue_b::gemm::WeightKBlockNFloat, float>(2, 4096, 4096, 32, BTLA_DTYPE::F4_BNB);
ut<sAVX2, prologue_b::gemm::WeightKBlockNFloat, float>(2, 4096, 4096, -1, BTLA_DTYPE::F4_BNB);
ut<sAVX2, prologue_b::gemm::WeightKBlockNFloat, float>(2, 4096, 4096, 32, BTLA_DTYPE::F4_E2M1);
ut<sAVX2, prologue_b::gemm::WeightKBlockNFloat, float>(2, 4096, 4096, -1, BTLA_DTYPE::F4_E2M1);
ut<sAVX2, prologue_b::gemm::WeightKBlockNFloat, float>(2, 4096, 4096, 32, BTLA_DTYPE::F4_NF4);
ut<sAVX2, prologue_b::gemm::WeightKBlockNFloat, float>(2, 4096, 4096, -1, BTLA_DTYPE::F4_NF4);
ut<sAVX2, prologue_b::gemm::WeightKBlockNFloat, utils::bf16>(2, 4096, 4096, 32, BTLA_DTYPE::F4_BNB);
ut<sAVX2, prologue_b::gemm::WeightKBlockNFloat, utils::bf16>(2, 4096, 4096, 32, BTLA_DTYPE::F4_E2M1);
ut<sAVX2, prologue_b::gemm::WeightKBlockNFloat, utils::bf16>(2, 4096, 4096, 32, BTLA_DTYPE::F4_NF4);
ut<sAVX2, prologue_b::gemm::WeightKBlockNFloat>(2, 4096, 4096, 32, BTLA_DTYPE::F4_BNB, BTLA_DTYPE::F32);
ut<sAVX2, prologue_b::gemm::WeightKBlockNFloat>(2, 4096, 4096, -1, BTLA_DTYPE::F4_BNB, BTLA_DTYPE::F32);
ut<sAVX2, prologue_b::gemm::WeightKBlockNFloat>(2, 4096, 4096, 32, BTLA_DTYPE::F4_E2M1, BTLA_DTYPE::F32);
ut<sAVX2, prologue_b::gemm::WeightKBlockNFloat>(2, 4096, 4096, -1, BTLA_DTYPE::F4_E2M1, BTLA_DTYPE::F32);
ut<sAVX2, prologue_b::gemm::WeightKBlockNFloat>(2, 4096, 4096, 32, BTLA_DTYPE::F4_NF4, BTLA_DTYPE::F32);
ut<sAVX2, prologue_b::gemm::WeightKBlockNFloat>(2, 4096, 4096, -1, BTLA_DTYPE::F4_NF4, BTLA_DTYPE::F32);
ut<sAVX2, prologue_b::gemm::WeightKBlockNFloat>(2, 4096, 4096, 32, BTLA_DTYPE::F4_BNB, BTLA_DTYPE::BF16);
ut<sAVX2, prologue_b::gemm::WeightKBlockNFloat>(2, 4096, 4096, 32, BTLA_DTYPE::F4_E2M1, BTLA_DTYPE::BF16);
ut<sAVX2, prologue_b::gemm::WeightKBlockNFloat>(2, 4096, 4096, 32, BTLA_DTYPE::F4_NF4, BTLA_DTYPE::BF16);

CheckISA(AVX512F);
ut<sAVX512F, prologue_b::gemm::WeightKBlockNFloat, float>(2, 4096, 4096, 32, BTLA_DTYPE::F4_BNB);
ut<sAVX512F, prologue_b::gemm::WeightKBlockNFloat, float>(2, 4096, 4096, -1, BTLA_DTYPE::F4_BNB);
ut<sAVX512F, prologue_b::gemm::WeightKBlockNFloat, float>(2, 4096, 4096, 32, BTLA_DTYPE::F4_E2M1);
ut<sAVX512F, prologue_b::gemm::WeightKBlockNFloat, float>(2, 4096, 4096, -1, BTLA_DTYPE::F4_E2M1);
ut<sAVX512F, prologue_b::gemm::WeightKBlockNFloat, float>(2, 4096, 4096, 32, BTLA_DTYPE::F4_NF4);
ut<sAVX512F, prologue_b::gemm::WeightKBlockNFloat, float>(2, 4096, 4096, -1, BTLA_DTYPE::F4_NF4);
ut<sAVX512F, prologue_b::gemm::WeightKBlockNFloat, utils::bf16>(2, 4096, 4096, 32, BTLA_DTYPE::F4_BNB);
ut<sAVX512F, prologue_b::gemm::WeightKBlockNFloat, utils::bf16>(2, 4096, 4096, 32, BTLA_DTYPE::F4_E2M1);
ut<sAVX512F, prologue_b::gemm::WeightKBlockNFloat, utils::bf16>(2, 4096, 4096, 32, BTLA_DTYPE::F4_NF4);
ut<sAVX512F, prologue_b::gemm::WeightKBlockNFloat>(2, 4096, 4096, 32, BTLA_DTYPE::F4_BNB, BTLA_DTYPE::F32);
ut<sAVX512F, prologue_b::gemm::WeightKBlockNFloat>(2, 4096, 4096, -1, BTLA_DTYPE::F4_BNB, BTLA_DTYPE::F32);
ut<sAVX512F, prologue_b::gemm::WeightKBlockNFloat>(2, 4096, 4096, 32, BTLA_DTYPE::F4_E2M1, BTLA_DTYPE::F32);
ut<sAVX512F, prologue_b::gemm::WeightKBlockNFloat>(2, 4096, 4096, -1, BTLA_DTYPE::F4_E2M1, BTLA_DTYPE::F32);
ut<sAVX512F, prologue_b::gemm::WeightKBlockNFloat>(2, 4096, 4096, 32, BTLA_DTYPE::F4_NF4, BTLA_DTYPE::F32);
ut<sAVX512F, prologue_b::gemm::WeightKBlockNFloat>(2, 4096, 4096, -1, BTLA_DTYPE::F4_NF4, BTLA_DTYPE::F32);
ut<sAVX512F, prologue_b::gemm::WeightKBlockNFloat>(2, 4096, 4096, 32, BTLA_DTYPE::F4_BNB, BTLA_DTYPE::BF16);
ut<sAVX512F, prologue_b::gemm::WeightKBlockNFloat>(2, 4096, 4096, 32, BTLA_DTYPE::F4_E2M1, BTLA_DTYPE::BF16);
ut<sAVX512F, prologue_b::gemm::WeightKBlockNFloat>(2, 4096, 4096, 32, BTLA_DTYPE::F4_NF4, BTLA_DTYPE::BF16);
}

template <class GemmCore_T, template <class _T, BTLA_ISA> class Wei>
Expand Down Expand Up @@ -664,10 +672,10 @@ class UT_CompFp32 {
buffer_error(refCupk.data(), matC.data(), refCupk.size(), 0.001f);
}

template <class GemmCore_T, template <class _T, BTLA_ISA> class Wei, typename Scale_T>
void ut(int m, int n, int k, int blocksize, BTLA_DTYPE qtype) {
template <class GemmCore_T, template <class _T, BTLA_ISA> class Wei>
void ut(int m, int n, int k, int blocksize, BTLA_DTYPE qtype, BTLA_DTYPE stype) {
printf("Test Case %s: %d %d %d-%d type:%s core:%s scaletype:%s\n", __FUNCTION__, m, n, k, blocksize,
bestla_dtype_str(qtype), gemm::CoreAttr::to_str(GemmCore_T::ID), type_str<Scale_T>);
bestla_dtype_str(qtype), gemm::CoreAttr::to_str(GemmCore_T::ID), bestla_dtype_str(stype));
auto constexpr ISA = GemmCore_T::ISA;
using Launcher =
wrapper::gemm::LauncherKBlock<ISA, GemmCore_T, prologue_a::gemm::ActivationBase, Wei,
Expand All @@ -677,8 +685,7 @@ class UT_CompFp32 {
blocksize = blocksize == -1 ? k : blocksize;
using WType = typename Wei<GemmCore_T, ISA>::StorageWeight;
WType packedw(0);
packedw = launcher.mProB.createStorage(n, k, blocksize, qtype, bestla_dtype<Scale_T>);

packedw = launcher.mProB.createStorage(n, k, blocksize, qtype, stype);
utils::avector<int8_t> buffer(packedw.mSize);
packedw.assign(buffer.data());
avector<float> matBf32(k * n), matAf32(m * k), matC(m * n), refC(m * n), refCupk(m * n);
Expand Down Expand Up @@ -946,19 +953,22 @@ class UT_CompInt8 {
void ut_s4_newkblock() {
GetCPUDevice();
if (_cd->AVX_VNNI()) {
ut_newkblock<gemm::ICoreRowNAvxvnniKBlock<48, 1>, float>(1, 11008, 4096, 32, BTLA_DTYPE::S4_CLIP);
ut_newkblock<gemm::ICoreRowNAvxvnniKBlock<48, 1>, float>(2, 4096, 4096, 32, BTLA_DTYPE::S4_CLIP);
ut_newkblock<gemm::ICoreRowNAvxvnniKBlock<48, 1>>(1, 11008, 4096, 32, BTLA_DTYPE::S4_CLIP, BTLA_DTYPE::F32);
ut_newkblock<gemm::ICoreRowNAvxvnniKBlock<48, 1>>(2, 4096, 4096, 32, BTLA_DTYPE::S4_CLIP, BTLA_DTYPE::F32);
ut_newkblock<gemm::ICoreRowNAvxvnniKBlock<48, 1>>(2, 4096, 4096, 32, BTLA_DTYPE::S4_CLIP, BTLA_DTYPE::DQ8_BNB);
}

if (_cd->AVX512_VNNI()) {
ut_newkblock<gemm::ICoreRowNAvx512vnniKBlock<48, 4>, float>(1, 11008, 4096, 32, BTLA_DTYPE::S4_CLIP);
ut_newkblock<gemm::ICoreRowNAvx512vnniKBlock<48, 4>, float>(2, 4096, 4096, 32, BTLA_DTYPE::S4_CLIP);
ut_newkblock<gemm::ICoreRowNAvx512vnniKBlock<48, 4>>(1, 11008, 4096, 32, BTLA_DTYPE::S4_CLIP, BTLA_DTYPE::F32);
ut_newkblock<gemm::ICoreRowNAvx512vnniKBlock<48, 4>>(2, 4096, 4096, 32, BTLA_DTYPE::S4_CLIP, BTLA_DTYPE::F32);
}

if (_cd->AMX_INT8()) {
request_perm_xtile_data();
ut_newkblock<gemm::ICoreRowNAmxint8KBlock<48, 16>, float>(128, 4096, 4096, 128, BTLA_DTYPE::S4_CLIP);
ut_newkblock<gemm::ICoreRowNAmxint8KBlock<48, 16>, float>(1, 4096, 4096, 64, BTLA_DTYPE::S4_CLIP);
ut_newkblock<gemm::ICoreRowNAmxint8KBlock<48, 16>>(128, 4096, 4096, 128, BTLA_DTYPE::S4_CLIP, BTLA_DTYPE::F32);
ut_newkblock<gemm::ICoreRowNAmxint8KBlock<48, 16>>(1, 4096, 4096, 64, BTLA_DTYPE::S4_CLIP, BTLA_DTYPE::F32);
ut_newkblock<gemm::ICoreRowNAmxint8KBlock<48, 16>>(128, 4096, 4096, 128, BTLA_DTYPE::S4_CLIP,
BTLA_DTYPE::DQ8_BNB);
}
}

Expand Down Expand Up @@ -986,10 +996,10 @@ class UT_CompInt8 {
}
}

template <class GemmCore_T, typename Scale_T>
void ut_newkblock(int m, int n, int k, int blocksize, BTLA_DTYPE qtype, bool isAsym = false) {
template <class GemmCore_T>
void ut_newkblock(int m, int n, int k, int blocksize, BTLA_DTYPE qtype, BTLA_DTYPE stype, bool isAsym = false) {
printf("Test Case %s: %d %d %d-%d type:%s core:%s scaletype:%s Asym:%d\n", __FUNCTION__, m, n, k, blocksize,
bestla_dtype_str(qtype), gemm::CoreAttr::to_str(GemmCore_T::ID), type_str<Scale_T>, isAsym);
bestla_dtype_str(qtype), gemm::CoreAttr::to_str(GemmCore_T::ID), bestla_dtype_str(stype), isAsym);
auto constexpr ISA = GemmCore_T::ISA;
using Launcher = wrapper::gemm::LauncherIntKBlock<ISA, GemmCore_T, prologue_a::gemm::ActivationF32KBlockQuantize,
prologue_b::gemm::WeightKBlockNInteger,
Expand All @@ -999,9 +1009,7 @@ class UT_CompInt8 {
blocksize = blocksize == -1 ? k : blocksize;
int kblks = updiv(k, blocksize);
using WType = typename Launcher::PrologueB::StorageWeight;
WType packedw =
launcher.mProB.createStorage(n, k, blocksize, qtype, bestla_dtype<Scale_T>, bestla_dtype<float>, isAsym);

WType packedw = launcher.mProB.createStorage(n, k, blocksize, qtype, stype, bestla_dtype<float>, isAsym);
utils::avector<int8_t> buffer(packedw.mSize);
packedw.assign(buffer.data());
avector<float> matBf32(k * n), matAf32(m * k), matC(m * n), refC(m * n), refCupk(m * n);
Expand Down Expand Up @@ -1042,7 +1050,12 @@ class UT_CompInt8 {
}
}
buffer_error(refC.data(), matC.data(), refC.size(), err);
buffer_error(refCupk.data(), matC.data(), refCupk.size(), INT8_ERR); // dynamic quant error
if (stype != BTLA_DTYPE::DQ8_BNB) {
buffer_error(refCupk.data(), matC.data(), refCupk.size(), INT8_ERR); // dynamic quant error
} else {
auto DQ_INT8_ERR = 0.8f;
buffer_error(refCupk.data(), matC.data(), refCupk.size(), DQ_INT8_ERR); // dynamic quant error
}
}

template <class GemmCore_T, typename Scale_T>
Expand Down

0 comments on commit e2eec50

Please sign in to comment.