From e2eec504d65945f957892a779c30f2fb343ba7dd Mon Sep 17 00:00:00 2001 From: "Wang,Zhe" Date: Thu, 18 Jan 2024 08:52:41 +0800 Subject: [PATCH] ut --- bestla/bestla/ut/bestla_prologue_b.cpp | 129 ++++++++++++++----------- 1 file changed, 71 insertions(+), 58 deletions(-) diff --git a/bestla/bestla/ut/bestla_prologue_b.cpp b/bestla/bestla/ut/bestla_prologue_b.cpp index 3280e4e4f..96d8e1af9 100644 --- a/bestla/bestla/ut/bestla_prologue_b.cpp +++ b/bestla/bestla/ut/bestla_prologue_b.cpp @@ -177,17 +177,20 @@ class UT_TransposeBlockQuantize_F4 { UT_TransposeBlockQuantize_F4() { UT_START(); CheckISA(AVX512F); - ut(4096, 4096, 32, BTLA_DTYPE::F4_BNB); - ut(1024, 4096, 32, BTLA_DTYPE::F4_BNB); - ut(4096, 1024, 32, BTLA_DTYPE::F4_BNB); - ut(48, 32, 32, BTLA_DTYPE::F4_BNB); - ut(32, 32, 32, BTLA_DTYPE::F4_BNB); - ut(48, 32, 32, BTLA_DTYPE::F4_BNB); - ut(48, 32, 32, BTLA_DTYPE::F4_NF4); - ut(48, 32, 32, BTLA_DTYPE::F4_E2M1); + ut(4096, 4096, 32, BTLA_DTYPE::F4_BNB, BTLA_DTYPE::F32); + ut(1024, 4096, 32, BTLA_DTYPE::F4_BNB, BTLA_DTYPE::F32); + ut(4096, 1024, 32, BTLA_DTYPE::F4_BNB, BTLA_DTYPE::F32); + ut(48, 32, 32, BTLA_DTYPE::F4_BNB, BTLA_DTYPE::F32); + ut(32, 32, 32, BTLA_DTYPE::F4_BNB, BTLA_DTYPE::F32); + ut(48, 32, 32, BTLA_DTYPE::F4_BNB, BTLA_DTYPE::F32); + ut(48, 32, 32, BTLA_DTYPE::F4_NF4, BTLA_DTYPE::F32); + ut(48, 32, 32, BTLA_DTYPE::F4_E2M1, BTLA_DTYPE::F32); + ut(16, 15, 8, BTLA_DTYPE::F4_NF4, BTLA_DTYPE::DQ8_BNB); + ut(48, 32, 16, BTLA_DTYPE::F4_NF4, BTLA_DTYPE::DQ8_BNB); + ut(1024, 4096, 32, BTLA_DTYPE::F4_NF4, BTLA_DTYPE::DQ8_BNB); } - void ut(int n, int k, int blocksize, BTLA_DTYPE F4_T) { + void ut(int n, int k, int blocksize, BTLA_DTYPE F4_T, BTLA_DTYPE SCA_T) { printf("Test Case: %d %d %d\n", n, k, blocksize); int ldb = n; utils::aligned_vector dequanRef(n * k); @@ -246,18 +249,23 @@ class UT_TransposeBlockQuantize_F4 { auto constexpr RuntimeISA = BTLA_ISA::AVX512F; using PrologueB = prologue_b::gemm::WeightKBlockNFloat, RuntimeISA>; PrologueB kernel; - auto packedW = kernel.createStorage(n, k, blocksize, F4_T, bestla_dtype); - auto packedW1 = kernel.createStorage(n, k, blocksize, F4_T, bestla_dtype); + auto packedW = kernel.createStorage(n, k, blocksize, F4_T, SCA_T); + auto packedW1 = kernel.createStorage(n, k, blocksize, F4_T, SCA_T); avector buf(packedW.mSize), buf1(packedW1.mSize); packedW.assign(buf.data()); packedW1.assign(buf1.data()); kernel.packTransposeWeight(n, k, dequanRef.data(), k, &packedW, &DefaultThreading); kernel.packQWeight(n, k, quanW.data(), ldb, scales.data(), nullptr, &packedW1, &DefaultThreading); - ut::buffer_error(packedW.SPtr(), packedW1.SPtr(), packedW1.CSize()); - ut::buffer_error(packedW.WPtr(), packedW1.WPtr(), packedW1.mQBuf.size()); avector dequant(n * k); kernel.unpackTransposeWeight(n, k, &packedW1, dequant.data(), k, &DefaultThreading); - ut::buffer_error(dequanRef.data(), dequant.data(), dequant.size()); + if (SCA_T != BTLA_DTYPE::DQ8_BNB) { + ut::buffer_error(packedW.SPtr(), packedW1.SPtr(), packedW1.CSize()); + ut::buffer_error(dequanRef.data(), dequant.data(), dequant.size()); + } else { + ut::buffer_error(packedW.SPtr(), packedW1.SPtr(), packedW1.CSize()); + ut::buffer_error(dequanRef.data(), dequant.data(), dequant.size(), 0.1f); + } + ut::buffer_error(packedW.WPtr(), packedW1.WPtr(), packedW1.mQBuf.size()); } }; #ifdef BTLA_UT_PROLOGUE_B @@ -523,15 +531,15 @@ class UT_CompFp32 { void ut_f8() { CheckISA(AVX2); - ut(2, 4096, 4096, 32, BTLA_DTYPE::F8_E4M3); - ut(2, 4096, 4096, 32, BTLA_DTYPE::F8_E4M3); - ut(2, 4096, 4096, 32, BTLA_DTYPE::F8_E5M2); - ut(2, 4096, 4096, 32, BTLA_DTYPE::F8_E5M2); + ut(2, 4096, 4096, 32, BTLA_DTYPE::F8_E4M3, BTLA_DTYPE::F8_E8M0); + ut(2, 4096, 4096, 32, BTLA_DTYPE::F8_E4M3, BTLA_DTYPE::F32); + ut(2, 4096, 4096, 32, BTLA_DTYPE::F8_E5M2, BTLA_DTYPE::F8_E8M0); + ut(2, 4096, 4096, 32, BTLA_DTYPE::F8_E5M2, BTLA_DTYPE::F32); CheckISA(AVX512F); - ut(2, 4096, 4096, 32, BTLA_DTYPE::F8_E4M3); - ut(2, 4096, 4096, 32, BTLA_DTYPE::F8_E4M3); - ut(2, 4096, 4096, 32, BTLA_DTYPE::F8_E5M2); - ut(2, 4096, 4096, 32, BTLA_DTYPE::F8_E5M2); + ut(2, 4096, 4096, 32, BTLA_DTYPE::F8_E4M3, BTLA_DTYPE::F8_E8M0); + ut(2, 4096, 4096, 32, BTLA_DTYPE::F8_E4M3, BTLA_DTYPE::F32); + ut(2, 4096, 4096, 32, BTLA_DTYPE::F8_E5M2, BTLA_DTYPE::F8_E8M0); + ut(2, 4096, 4096, 32, BTLA_DTYPE::F8_E5M2, BTLA_DTYPE::F32); } void ut_s4() { CheckISA(AVX2); @@ -589,26 +597,26 @@ class UT_CompFp32 { void ut_f4() { CheckISA(AVX2); - ut(2, 4096, 4096, 32, BTLA_DTYPE::F4_BNB); - ut(2, 4096, 4096, -1, BTLA_DTYPE::F4_BNB); - ut(2, 4096, 4096, 32, BTLA_DTYPE::F4_E2M1); - ut(2, 4096, 4096, -1, BTLA_DTYPE::F4_E2M1); - ut(2, 4096, 4096, 32, BTLA_DTYPE::F4_NF4); - ut(2, 4096, 4096, -1, BTLA_DTYPE::F4_NF4); - ut(2, 4096, 4096, 32, BTLA_DTYPE::F4_BNB); - ut(2, 4096, 4096, 32, BTLA_DTYPE::F4_E2M1); - ut(2, 4096, 4096, 32, BTLA_DTYPE::F4_NF4); + ut(2, 4096, 4096, 32, BTLA_DTYPE::F4_BNB, BTLA_DTYPE::F32); + ut(2, 4096, 4096, -1, BTLA_DTYPE::F4_BNB, BTLA_DTYPE::F32); + ut(2, 4096, 4096, 32, BTLA_DTYPE::F4_E2M1, BTLA_DTYPE::F32); + ut(2, 4096, 4096, -1, BTLA_DTYPE::F4_E2M1, BTLA_DTYPE::F32); + ut(2, 4096, 4096, 32, BTLA_DTYPE::F4_NF4, BTLA_DTYPE::F32); + ut(2, 4096, 4096, -1, BTLA_DTYPE::F4_NF4, BTLA_DTYPE::F32); + ut(2, 4096, 4096, 32, BTLA_DTYPE::F4_BNB, BTLA_DTYPE::BF16); + ut(2, 4096, 4096, 32, BTLA_DTYPE::F4_E2M1, BTLA_DTYPE::BF16); + ut(2, 4096, 4096, 32, BTLA_DTYPE::F4_NF4, BTLA_DTYPE::BF16); CheckISA(AVX512F); - ut(2, 4096, 4096, 32, BTLA_DTYPE::F4_BNB); - ut(2, 4096, 4096, -1, BTLA_DTYPE::F4_BNB); - ut(2, 4096, 4096, 32, BTLA_DTYPE::F4_E2M1); - ut(2, 4096, 4096, -1, BTLA_DTYPE::F4_E2M1); - ut(2, 4096, 4096, 32, BTLA_DTYPE::F4_NF4); - ut(2, 4096, 4096, -1, BTLA_DTYPE::F4_NF4); - ut(2, 4096, 4096, 32, BTLA_DTYPE::F4_BNB); - ut(2, 4096, 4096, 32, BTLA_DTYPE::F4_E2M1); - ut(2, 4096, 4096, 32, BTLA_DTYPE::F4_NF4); + ut(2, 4096, 4096, 32, BTLA_DTYPE::F4_BNB, BTLA_DTYPE::F32); + ut(2, 4096, 4096, -1, BTLA_DTYPE::F4_BNB, BTLA_DTYPE::F32); + ut(2, 4096, 4096, 32, BTLA_DTYPE::F4_E2M1, BTLA_DTYPE::F32); + ut(2, 4096, 4096, -1, BTLA_DTYPE::F4_E2M1, BTLA_DTYPE::F32); + ut(2, 4096, 4096, 32, BTLA_DTYPE::F4_NF4, BTLA_DTYPE::F32); + ut(2, 4096, 4096, -1, BTLA_DTYPE::F4_NF4, BTLA_DTYPE::F32); + ut(2, 4096, 4096, 32, BTLA_DTYPE::F4_BNB, BTLA_DTYPE::BF16); + ut(2, 4096, 4096, 32, BTLA_DTYPE::F4_E2M1, BTLA_DTYPE::BF16); + ut(2, 4096, 4096, 32, BTLA_DTYPE::F4_NF4, BTLA_DTYPE::BF16); } template class Wei> @@ -664,10 +672,10 @@ class UT_CompFp32 { buffer_error(refCupk.data(), matC.data(), refCupk.size(), 0.001f); } - template class Wei, typename Scale_T> - void ut(int m, int n, int k, int blocksize, BTLA_DTYPE qtype) { + template class Wei> + void ut(int m, int n, int k, int blocksize, BTLA_DTYPE qtype, BTLA_DTYPE stype) { printf("Test Case %s: %d %d %d-%d type:%s core:%s scaletype:%s\n", __FUNCTION__, m, n, k, blocksize, - bestla_dtype_str(qtype), gemm::CoreAttr::to_str(GemmCore_T::ID), type_str); + bestla_dtype_str(qtype), gemm::CoreAttr::to_str(GemmCore_T::ID), bestla_dtype_str(stype)); auto constexpr ISA = GemmCore_T::ISA; using Launcher = wrapper::gemm::LauncherKBlock::StorageWeight; WType packedw(0); - packedw = launcher.mProB.createStorage(n, k, blocksize, qtype, bestla_dtype); - + packedw = launcher.mProB.createStorage(n, k, blocksize, qtype, stype); utils::avector buffer(packedw.mSize); packedw.assign(buffer.data()); avector matBf32(k * n), matAf32(m * k), matC(m * n), refC(m * n), refCupk(m * n); @@ -946,19 +953,22 @@ class UT_CompInt8 { void ut_s4_newkblock() { GetCPUDevice(); if (_cd->AVX_VNNI()) { - ut_newkblock, float>(1, 11008, 4096, 32, BTLA_DTYPE::S4_CLIP); - ut_newkblock, float>(2, 4096, 4096, 32, BTLA_DTYPE::S4_CLIP); + ut_newkblock>(1, 11008, 4096, 32, BTLA_DTYPE::S4_CLIP, BTLA_DTYPE::F32); + ut_newkblock>(2, 4096, 4096, 32, BTLA_DTYPE::S4_CLIP, BTLA_DTYPE::F32); + ut_newkblock>(2, 4096, 4096, 32, BTLA_DTYPE::S4_CLIP, BTLA_DTYPE::DQ8_BNB); } if (_cd->AVX512_VNNI()) { - ut_newkblock, float>(1, 11008, 4096, 32, BTLA_DTYPE::S4_CLIP); - ut_newkblock, float>(2, 4096, 4096, 32, BTLA_DTYPE::S4_CLIP); + ut_newkblock>(1, 11008, 4096, 32, BTLA_DTYPE::S4_CLIP, BTLA_DTYPE::F32); + ut_newkblock>(2, 4096, 4096, 32, BTLA_DTYPE::S4_CLIP, BTLA_DTYPE::F32); } if (_cd->AMX_INT8()) { request_perm_xtile_data(); - ut_newkblock, float>(128, 4096, 4096, 128, BTLA_DTYPE::S4_CLIP); - ut_newkblock, float>(1, 4096, 4096, 64, BTLA_DTYPE::S4_CLIP); + ut_newkblock>(128, 4096, 4096, 128, BTLA_DTYPE::S4_CLIP, BTLA_DTYPE::F32); + ut_newkblock>(1, 4096, 4096, 64, BTLA_DTYPE::S4_CLIP, BTLA_DTYPE::F32); + ut_newkblock>(128, 4096, 4096, 128, BTLA_DTYPE::S4_CLIP, + BTLA_DTYPE::DQ8_BNB); } } @@ -986,10 +996,10 @@ class UT_CompInt8 { } } - template - void ut_newkblock(int m, int n, int k, int blocksize, BTLA_DTYPE qtype, bool isAsym = false) { + template + void ut_newkblock(int m, int n, int k, int blocksize, BTLA_DTYPE qtype, BTLA_DTYPE stype, bool isAsym = false) { printf("Test Case %s: %d %d %d-%d type:%s core:%s scaletype:%s Asym:%d\n", __FUNCTION__, m, n, k, blocksize, - bestla_dtype_str(qtype), gemm::CoreAttr::to_str(GemmCore_T::ID), type_str, isAsym); + bestla_dtype_str(qtype), gemm::CoreAttr::to_str(GemmCore_T::ID), bestla_dtype_str(stype), isAsym); auto constexpr ISA = GemmCore_T::ISA; using Launcher = wrapper::gemm::LauncherIntKBlock, bestla_dtype, isAsym); - + WType packedw = launcher.mProB.createStorage(n, k, blocksize, qtype, stype, bestla_dtype, isAsym); utils::avector buffer(packedw.mSize); packedw.assign(buffer.data()); avector matBf32(k * n), matAf32(m * k), matC(m * n), refC(m * n), refCupk(m * n); @@ -1042,7 +1050,12 @@ class UT_CompInt8 { } } buffer_error(refC.data(), matC.data(), refC.size(), err); - buffer_error(refCupk.data(), matC.data(), refCupk.size(), INT8_ERR); // dynamic quant error + if (stype != BTLA_DTYPE::DQ8_BNB) { + buffer_error(refCupk.data(), matC.data(), refCupk.size(), INT8_ERR); // dynamic quant error + } else { + auto DQ_INT8_ERR = 0.8f; + buffer_error(refCupk.data(), matC.data(), refCupk.size(), DQ_INT8_ERR); // dynamic quant error + } } template