diff --git a/intel_extension_for_transformers/llm/library/jblas/jblas/jit_blas_epilogue.h b/intel_extension_for_transformers/llm/library/jblas/jblas/jit_blas_epilogue.h index 61ccb51aa42..53cbd17105f 100644 --- a/intel_extension_for_transformers/llm/library/jblas/jblas/jit_blas_epilogue.h +++ b/intel_extension_for_transformers/llm/library/jblas/jblas/jit_blas_epilogue.h @@ -137,13 +137,19 @@ template class ZpDequantInt32ToFp32 { public: struct Param { + // necessary float* C; int ldc; - uint8_t* zpA; - float* scalesA; int ldsa; - float* reduceB; + float* scalesA; float* scalesB; + // optional if A asym + uint8_t* zpA = nullptr; + float* reduceB = nullptr; + // optional if B asym + int8_t* zpB = nullptr; + float* reduceA = nullptr; + int K = 1; }; JBLAS_CODE forward(const int32_t* cacheptr, const int cachestep, const int M_offset, const int N_offset, const int M, const int N, const Param& _param) { @@ -155,9 +161,22 @@ class ZpDequantInt32ToFp32 { if (ret != JblasSuccess) { return ret; } - ret = kernel::wrapper::RemoveZeroPointBias::template forward( - cptr, _param.ldc, M, N, _param.zpA + M_offset * _param.ldsa, _param.scalesA + M_offset * _param.ldsa, - _param.ldsa, _param.reduceB + N_offset); + if (_param.zpA == nullptr && _param.zpB == nullptr) { + return ret; + } else if (_param.zpA != nullptr && _param.zpB == nullptr) { + ret = kernel::wrapper::RemoveZeroPointBias::template forward( + cptr, _param.ldc, M, N, _param.zpA + M_offset * _param.ldsa, _param.scalesA + M_offset * _param.ldsa, + _param.ldsa, _param.reduceB + N_offset); + } else if (_param.zpA == nullptr && _param.zpB != nullptr) { + ret = kernel::wrapper::RemoveZeroPointBias::template forward(cptr, _param.ldc, M, N, _param.zpB + N_offset, + _param.scalesB + N_offset, _param.ldsa, + _param.reduceA + M_offset * _param.ldsa); + } else { + ret = kernel::wrapper::RemoveZeroPointBias::template forward( + cptr, _param.ldc, M, N, _param.zpA + M_offset * _param.ldsa, _param.zpB + N_offset, + _param.scalesA + 
M_offset * _param.ldsa, _param.scalesB + N_offset, _param.ldsa, _param.K, + _param.reduceA + M_offset * _param.ldsa, _param.reduceB + N_offset); + } return ret; } }; diff --git a/intel_extension_for_transformers/llm/library/jblas/jblas/jit_blas_prologue.h b/intel_extension_for_transformers/llm/library/jblas/jblas/jit_blas_prologue.h index 5f4f57f9697..ea33b065714 100644 --- a/intel_extension_for_transformers/llm/library/jblas/jblas/jit_blas_prologue.h +++ b/intel_extension_for_transformers/llm/library/jblas/jblas/jit_blas_prologue.h @@ -819,9 +819,10 @@ class WeightPack { } void packWeightTranspose(const int N, const int K, const Param& _param) { - utils::aligned_vector B_NT(N * K); - transposeWeight(N, K, _param.B, _param.ldb, B_NT.data(), N); - return packWeight(N, K, {B_NT.data(), N, _param.packedW}); + auto B_NT = utils::amalloc((size_t)N * K); + transposeWeight(N, K, _param.B, _param.ldb, B_NT, N); + packWeight(N, K, {B_NT, N, _param.packedW}); + utils::afree(B_NT); } // from KxN int8 symmetric weight to packed N//NtilexKPadxNTile int4 weight diff --git a/intel_extension_for_transformers/llm/library/jblas/jblas/jit_blas_utils.h b/intel_extension_for_transformers/llm/library/jblas/jblas/jit_blas_utils.h index 27d5aaff766..6b583da6d67 100644 --- a/intel_extension_for_transformers/llm/library/jblas/jblas/jit_blas_utils.h +++ b/intel_extension_for_transformers/llm/library/jblas/jblas/jit_blas_utils.h @@ -12,7 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. #pragma once +#ifdef _OPENMP #include +#endif #include #include @@ -20,6 +22,11 @@ #include #include #include +#ifdef _WIN32 +#include +#else +#include +#endif #include "jit_blas.h" #include "xbyak/xbyak_util.h" @@ -279,7 +286,7 @@ static inline _DSTT cast(_SRCT _src) { template <> int8_t cast(float _src) { - _src = _src >= 0.f ? 
_src + 0.5f : _src - 0.5f; + _src = roundf(_src); _src = std::min(_src, 127.f); _src = std::max(_src, -128.f); return static_cast(_src); @@ -320,11 +327,36 @@ _T deserialize(int8_t*& buf) { static inline int padto(int a, int b) { return updiv(a, b) * b; } static inline size_t padto(size_t a, int b) { return updiv(a, b) * b; } +template +static inline _T* amalloc(size_t _size, size_t _alignment = 64) { + if (_size == 0) { + return NULL; + } + auto psize = padto(_size * sizeof(_T), _alignment); +#ifdef _WIN32 + return (_T*)_aligned_malloc(psize, _alignment); +#else + return (_T*)aligned_alloc(_alignment, psize); +#endif +} + +static inline void afree(void* ptr) { + if (ptr == NULL) { + return; + } +#ifdef _WIN32 + _aligned_free(ptr); +#else + free(ptr); +#endif +} + template class aligned_vector { public: aligned_vector() : mRawsize(0), mPtr(nullptr), mAlignedsize(0) {} - aligned_vector(size_t _size, _T _val = _T(0)) { + aligned_vector(size_t _size) { resize(_size); } + aligned_vector(size_t _size, _T _val) { resize(_size); std::fill_n(mVec.begin(), mVec.size(), _val); } @@ -502,12 +534,15 @@ class CpuDevice { ADD_FLAG(AVX512_BF16); ADD_FLAG(AVX512_FP16); numcores = _cpu.getNumCores(Xbyak::util::IntelCpuTopologyLevel::CoreLevel); +#ifdef _OPENMP ompthreads = omp_get_max_threads(); - numthreads = std::min(numcores, ompthreads); -#ifdef FORCE_NUM_THREADS - numthreads = FORCE_NUM_THREADS; +#else + ompthreads = numcores; #endif + numthreads = std::min(numcores, ompthreads); +#ifdef _OPENMP omp_set_num_threads(numthreads); +#endif } static CpuDevice* getInstance() { diff --git a/intel_extension_for_transformers/llm/library/jblas/jblas/jit_blas_weight_compression.h b/intel_extension_for_transformers/llm/library/jblas/jblas/jit_blas_weight_compression.h index 8c27ace9c8a..07e0a3a7ac0 100644 --- a/intel_extension_for_transformers/llm/library/jblas/jblas/jit_blas_weight_compression.h +++ b/intel_extension_for_transformers/llm/library/jblas/jblas/jit_blas_weight_compression.h 
@@ -138,10 +138,14 @@ class WeightS8ScaleFp32 { int KPad = utils::padto(k, _GemmCore_T::KTILE); int NPad = utils::padto(n, _GemmCore_T::NTILE); StorageWeight tmp(_GemmCore_T::TYPE); - tmp.resize(NPad, KPad, blocksize <= 0 ? k : blocksize, is_asym); + tmp.resize(NPad, KPad, blocksize <= 0 ? KPad : blocksize, is_asym); return tmp; } + Parallel createParallel(const int N, const int K) { + return Parallel(); // no runtime parallel forward + } + Parallel createParallel(const int N, const int K, const int blocksize) { assert(0); return Parallel(); // no runtime parallel forward @@ -154,28 +158,33 @@ class WeightS8ScaleFp32 { // from K*N fp32 weight to packed N//NtilexKPadxNTile weight virtual void packTransposeWeight(const int N, const int K, const float* B, const int ldb, void* stor) { - utils::aligned_vector B_NT(N * K); - prologue::gemm::transposeWeight(N, K, B, ldb, B_NT.data(), N); - packWeight(N, K, B_NT.data(), N, stor); + auto B_NT = utils::amalloc((size_t)N * K); + prologue::gemm::transposeWeight(N, K, B, ldb, B_NT, N); + packWeight(N, K, B_NT, N, stor); + utils::afree(B_NT); } // from packed N//NtilexKPadxNTile int8 weight to KxN f32 weight virtual void unpackTransposeWeight(const int N, const int K, void* stor, float* B, const int ldb) { - utils::aligned_vector B_NT(N * K); - unpackWeight(N, K, stor, B_NT.data(), N); - prologue::gemm::transposeWeight(K, N, B_NT.data(), N, B, ldb); + auto B_NT = utils::amalloc((size_t)N * K); + unpackWeight(N, K, stor, B_NT, N); + prologue::gemm::transposeWeight(K, N, B_NT, N, B, ldb); + utils::afree(B_NT); } // from KxN f32 weight to packed N//NtilexKPadxNTile int8 weight virtual void packWeight(const int N, const int K, const float* B, const int ldb, void* stor) { - utils::aligned_vector tmpq(N * K); + auto tmpq = utils::amalloc((size_t)N * K); auto ptr = reinterpret_cast(stor); int nk_scale = utils::updiv(K, ptr->mBlockSize); auto ssize = (size_t)N * nk_scale; - utils::avector Tscales(ssize); - utils::avector 
Tzps(ptr->mIsAsym ? ssize : 0); - quantizeWeight(N, K, B, ldb, ptr->mBlockSize, tmpq.data(), Tscales.data(), Tzps.data()); - packQWeight(N, K, tmpq.data(), ldb, Tscales.data(), Tzps.data(), stor); + auto Tscales = utils::amalloc(ssize); + auto Tzps = utils::amalloc(ptr->mIsAsym ? ssize : 0); + quantizeWeight(N, K, B, ldb, ptr->mBlockSize, tmpq, Tscales, Tzps); + packQWeight(N, K, tmpq, ldb, Tscales, Tzps, stor); + utils::afree(tmpq); + utils::afree(Tscales); + utils::afree(Tzps); } virtual void unpackWeight(const int N, const int K, void* stor, float* B, const int ldb) { @@ -192,14 +201,15 @@ class WeightS8ScaleFp32 { int rowremain = utils::remainsize(rowidx, K, rowsize); // rowremain: src valid size. rowsize: padded size int colremain = utils::remainsize(colidx, N, colsize); - std::vector dequant(rowsize * colsize); + auto dequant = utils::amalloc((size_t)rowsize * colsize); + auto dstptr = dequant; int dststep = 0; - auto dstptr = dequant.data(); auto rowpad = utils::padto(rowremain, _GemmCore_T::KTILE); auto colpad = utils::padto(colremain, _GemmCore_T::NTILE); getWeight(&dstptr, &dststep, rowpad, colpad, rowidx, colidx, {stor}); kernel::wrapper::RevertPaddingInterleaveMN<_GemmCore_T::NTILE, _GemmCore_T::PACK_ROW>::template forward( dstptr, B + rowidx * ldb + colidx, rowremain, colremain, rowpad, colpad, dststep, ldb); + utils::afree(dequant); } } } @@ -218,14 +228,15 @@ class WeightS8ScaleFp32 { int rowremain = utils::remainsize(rowidx, K, rowsize); // rowremain: src valid size. 
rowsize: padded size int colremain = utils::remainsize(colidx, N, colsize); - std::vector dequant(rowsize * colsize); + auto dequant = utils::amalloc((size_t)rowsize * colsize); int dststep = 0; - auto dstptr = dequant.data(); + auto dstptr = dequant; auto rowpad = utils::padto(rowremain, _GemmCore_T::KTILE); auto colpad = utils::padto(colremain, _GemmCore_T::NTILE); getWeight(&dstptr, &dststep, rowpad, colpad, rowidx, colidx, {stor}); kernel::wrapper::RevertPaddingInterleaveMN<_GemmCore_T::NTILE, _GemmCore_T::PACK_ROW>::template forward( dstptr, B + rowidx * ldb + colidx, rowremain, colremain, rowpad, colpad, dststep, ldb); + utils::afree(dequant); } } } @@ -252,9 +263,10 @@ class WeightS8ScaleFp32 { reorderWeight(N, K, B, ldb, stor->WPtr()); if (stor->mHasReduce) { - utils::avector deq(K * N); - unpackWeight(N, K, stor, deq.data(), N); - reduceWeight(N, K, stor->mBlockSize, deq.data(), ldb, stor->mRPtr, stor->mNPad); + auto deq = utils::amalloc((size_t)K * N); + unpackWeight(N, K, stor, deq, N); + reduceWeight(N, K, stor->mBlockSize, deq, ldb, stor->mRPtr, stor->mNPad); + utils::afree(deq); } } @@ -292,18 +304,12 @@ class WeightS8ScaleFp32 { auto KPad = wptr->mKPad; auto bptr = wptr->WPtr() + n_offset * KPad + k_offset * _GemmCore_T::NTILE; for (int i = 0; i < n_size; i += _GemmCore_T::NTILE) { - if constexpr (_GemmCore_T::PACK_ROW == 1) { - kernel::wrapper::DecompressKBlockS8F32::forward( - bptr + i * KPad, *dstptr + i * k_size, k_size / _GemmCore_T::PACK_ROW, - _GemmCore_T::NTILE * _GemmCore_T::PACK_ROW, _GemmCore_T::NTILE * _GemmCore_T::PACK_ROW, - _GemmCore_T::NTILE * _GemmCore_T::PACK_ROW, wptr->mSPtr + n_offset + i, - wptr->mZPtr != nullptr ? wptr->mZPtr + n_offset + i : nullptr, k_offset, wptr->mBlockSize, NPad); - } else { - kernel::wrapper::DecompressKBlockS8FP32PackRow::forward( - bptr + i * KPad, *dstptr + i * k_size, k_size, _GemmCore_T::NTILE, _GemmCore_T::NTILE, _GemmCore_T::NTILE, - wptr->mSPtr + n_offset + i, wptr->mZPtr != nullptr ? 
wptr->mZPtr + n_offset + i : nullptr, k_offset, - wptr->mBlockSize, NPad, _GemmCore_T::PACK_ROW); - } + kernel::wrapper::DecompressKBlockS8F32<_GemmCore_T::PACK_ROW>::template forward( + bptr + i * KPad, *dstptr + i * k_size, k_size / _GemmCore_T::PACK_ROW, + _GemmCore_T::NTILE * _GemmCore_T::PACK_ROW, _GemmCore_T::NTILE * _GemmCore_T::PACK_ROW, + _GemmCore_T::NTILE * _GemmCore_T::PACK_ROW, wptr->mSPtr + n_offset + i, + wptr->mZPtr != nullptr ? wptr->mZPtr + n_offset + i : nullptr, k_offset / _GemmCore_T::PACK_ROW, + wptr->mBlockSize / _GemmCore_T::PACK_ROW, NPad); } *dststep = k_size; return JblasSuccess; @@ -366,12 +372,9 @@ class WeightS8ScaleFp32 { { int tidx = omp_get_thread_num(); int colidx, rowidx, rowsize, colsize; - _para.getIndex(tidx, &rowidx, &colidx, &rowsize, &colsize); + _para.getIndex(tidx, &rowidx, &colidx, &rowsize, &colsize, false); if (rowsize > 0 && colsize > 0) { - int rowremain = utils::remainsize(rowidx, K, - rowsize); // rowremain: src valid size. rowsize: padded size - int colremain = utils::remainsize(colidx, N, colsize); - quantRowBlock(B + rowidx * ldb + colidx, qB + rowidx * N + colidx, rowremain, colremain, ldb, N, + quantRowBlock(B + rowidx * ldb + colidx, qB + rowidx * N + colidx, rowsize, colsize, ldb, N, scales + rowidx / bsize * N + colidx, zero_points == nullptr ? 
zero_points : zero_points + rowidx / bsize * N + colidx, bsize); } @@ -408,48 +411,22 @@ class WeightS8ScaleFp32 { class StorageWeightS8ScaleFp32PerChannelN : public StorageWeightS8ScaleFp32 { public: - using Parent = StorageWeightS8ScaleFp32; StorageWeightS8ScaleFp32PerChannelN(jblas::gemm::GemmCoreType _type) : StorageWeightS8ScaleFp32(_type) { mPrologueID = static_cast(PrologueBIDs::WeightS8ScaleFp32PerChannelN); } - - size_t resize(int NPad, int KPad, int K, bool IsAsym = true) { - return StorageWeightS8ScaleFp32::resize(NPad, KPad, K, IsAsym); - } }; template class WeightS8ScaleFp32PerChannelN : public WeightS8ScaleFp32<_GemmCore_T, ISA_T> { public: - using Parent = WeightS8ScaleFp32<_GemmCore_T, ISA_T>; - using Param = typename Parent::Param; using StorageWeight = StorageWeightS8ScaleFp32PerChannelN; - using SType = float; - using Parallel = utils::parallel::Parallel2DRowMajor; StorageWeight createStorage(const int N, const int K, bool is_asym) { int KPad = utils::padto(K, _GemmCore_T::KTILE); int NPad = utils::padto(N, _GemmCore_T::NTILE); StorageWeight tmp(_GemmCore_T::TYPE); - tmp.resize(NPad, KPad, K, is_asym); + tmp.resize(NPad, KPad, KPad, is_asym); return tmp; } - - Parallel createParallel(const int N, const int K) { - return Parallel(); // no runtime parallel forward - } - - virtual void packQWeight(const int N, const int K, const int8_t* B, const int ldb, const float* scales, - const int8_t* zero_points, void* ptr) override { - auto stor = reinterpret_cast(ptr); - std::memcpy(stor->mSPtr, scales, N * sizeof(scales[0])); - if (zero_points != nullptr) { - std::memcpy(stor->mZPtr, zero_points, N * sizeof(zero_points[0])); - } - WeightS8ScaleFp32<_GemmCore_T, ISA_T>::reorderWeight(N, K, B, ldb, stor->WPtr()); - utils::avector deq(K * N); - WeightS8ScaleFp32<_GemmCore_T, ISA_T>::unpackWeight(N, K, stor, deq.data(), N); - WeightS8ScaleFp32<_GemmCore_T, ISA_T>::reduceWeight(N, K, K, deq.data(), ldb, stor->mRPtr, stor->mNPad); - } }; class 
StorageWeightS4ScaleFp32 : public WeightBase, @@ -507,11 +484,12 @@ class WeightS4ScaleFp32 : public WeightS8ScaleFp32<_GemmCore_T, ISA_T> { public: using Param = typename WeightS8ScaleFp32<_GemmCore_T, ISA_T>::Param; using StorageWeight = StorageWeightS4ScaleFp32; + using ScaleType = float; StorageWeight createStorage(const int N, const int K, int blocksize, bool is_asym = false) { int KPad = utils::padto(K, _GemmCore_T::KTILE); int NPad = utils::padto(N, _GemmCore_T::NTILE); StorageWeight tmp(_GemmCore_T::TYPE, S4_T); - tmp.resize(NPad, KPad, blocksize <= 0 ? K : blocksize, is_asym); + tmp.resize(NPad, KPad, blocksize <= 0 ? KPad : blocksize, is_asym); return tmp; } @@ -540,14 +518,15 @@ class WeightS4ScaleFp32 : public WeightS8ScaleFp32<_GemmCore_T, ISA_T> { } } } - utils::avector reorded(stor->mKPad * stor->mNPad); - WeightS8ScaleFp32<_GemmCore_T, ISA_T>::reorderWeight(N, K, B, ldb, reorded.data()); - compressWeight(stor->mNPad, stor->mKPad, reorded.data(), stor->mNPad, stor->WPtr()); + auto reorded = utils::amalloc((size_t)stor->mKPad * stor->mNPad); + WeightS8ScaleFp32<_GemmCore_T, ISA_T>::reorderWeight(N, K, B, ldb, reorded); + compressWeight(stor->mNPad, stor->mKPad, reorded, stor->mNPad, stor->WPtr()); + utils::afree(reorded); if (stor->mHasReduce) { - utils::avector deq(K * N); - WeightS8ScaleFp32<_GemmCore_T, ISA_T>::unpackWeight(N, K, stor, deq.data(), N); - WeightS8ScaleFp32<_GemmCore_T, ISA_T>::reduceWeight(N, K, stor->mBlockSize, deq.data(), ldb, stor->mRPtr, - stor->mNPad); + auto deq = utils::amalloc((size_t)K * N); + WeightS8ScaleFp32<_GemmCore_T, ISA_T>::unpackWeight(N, K, stor, deq, N); + WeightS8ScaleFp32<_GemmCore_T, ISA_T>::reduceWeight(N, K, stor->mBlockSize, deq, ldb, stor->mRPtr, stor->mNPad); + utils::afree(deq); } } @@ -579,7 +558,7 @@ class WeightS4ScaleFp32 : public WeightS8ScaleFp32<_GemmCore_T, ISA_T> { auto KPad = wptr->mKPad; auto bptr = wptr->WPtr() + n_offset * KPad / 2 + k_offset * _GemmCore_T::NTILE / 2; for (int i = 0; i < 
n_size; i += _GemmCore_T::NTILE) { - kernel::wrapper::DecompressKBlockS4S8::forward( + kernel::wrapper::DecompressKBlockS4S8::template forward( (utils::int4x2*)(bptr + i * KPad / 2), *dstptr + i * k_size, k_size / _GemmCore_T::PACK_ROW, _GemmCore_T::NTILE * _GemmCore_T::PACK_ROW, _GemmCore_T::NTILE * _GemmCore_T::PACK_ROW, _GemmCore_T::NTILE * _GemmCore_T::PACK_ROW); @@ -590,45 +569,12 @@ class WeightS4ScaleFp32 : public WeightS8ScaleFp32<_GemmCore_T, ISA_T> { virtual inline JBLAS_CODE getWeight(float** dstptr, int* dststep, int k_size, int n_size, int k_offset, int n_offset, const Param& _param) override { - auto wptr = (StorageWeight*)(_param.packedW); - auto NPad = wptr->mNPad; - auto KPad = wptr->mKPad; - auto bptr = wptr->WPtr() + n_offset * KPad / 2 + k_offset * _GemmCore_T::NTILE / 2; - for (int i = 0; i < n_size; i += _GemmCore_T::NTILE) { - if constexpr (_GemmCore_T::PACK_ROW == 1) { - kernel::wrapper::DecompressKBlockS4FP::forward( - (utils::int4x2*)(bptr + i * KPad / 2), *dstptr + i * k_size, k_size / _GemmCore_T::PACK_ROW, - _GemmCore_T::NTILE * _GemmCore_T::PACK_ROW, _GemmCore_T::NTILE * _GemmCore_T::PACK_ROW, - _GemmCore_T::NTILE * _GemmCore_T::PACK_ROW, wptr->mSPtr + n_offset + i, - wptr->mZPtr != nullptr ? wptr->mZPtr + n_offset + i : nullptr, k_offset, wptr->mBlockSize, NPad); - } else { - kernel::wrapper::DecompressKBlockS4FPPackRow::forward( - (utils::int4x2*)(bptr + i * KPad / 2), *dstptr + i * k_size, k_size, _GemmCore_T::NTILE, _GemmCore_T::NTILE, - _GemmCore_T::NTILE, wptr->mSPtr + n_offset + i, - wptr->mZPtr != nullptr ? 
wptr->mZPtr + n_offset + i : nullptr, k_offset, wptr->mBlockSize, NPad, - _GemmCore_T::PACK_ROW); - } - } - *dststep = k_size; - return JblasSuccess; + return getFpWeight(dstptr, dststep, k_size, n_size, k_offset, n_offset, _param); } virtual inline JBLAS_CODE getWeight(utils::bf16** dstptr, int* dststep, int k_size, int n_size, int k_offset, int n_offset, const Param& _param) { - auto wptr = (StorageWeight*)(_param.packedW); - auto NPad = wptr->mNPad; - auto KPad = wptr->mKPad; - auto bptr = wptr->WPtr() + n_offset * KPad / 2 + k_offset * _GemmCore_T::NTILE / 2; - for (int i = 0; i < n_size; i += _GemmCore_T::NTILE) { - kernel::wrapper::DecompressKBlockS4FP::forward( - (utils::int4x2*)(bptr + i * KPad / 2), *dstptr + i * k_size, k_size / _GemmCore_T::PACK_ROW, - _GemmCore_T::NTILE * _GemmCore_T::PACK_ROW, _GemmCore_T::NTILE * _GemmCore_T::PACK_ROW, - _GemmCore_T::NTILE * _GemmCore_T::PACK_ROW, wptr->mSPtr + n_offset + i, - wptr->mZPtr != nullptr ? wptr->mZPtr + n_offset + i : nullptr, k_offset / _GemmCore_T::PACK_ROW, - wptr->mBlockSize / _GemmCore_T::PACK_ROW, NPad); - } - *dststep = k_size; - return JblasSuccess; + return getFpWeight(dstptr, dststep, k_size, n_size, k_offset, n_offset, _param); } virtual JBLAS_CODE getScale(float** dstptr, int* dststep, int n_size, int k_size, int n_offset, int k_offset, @@ -656,6 +602,25 @@ class WeightS4ScaleFp32 : public WeightS8ScaleFp32<_GemmCore_T, ISA_T> { } protected: + template + inline JBLAS_CODE getFpWeight(_T** dstptr, int* dststep, int k_size, int n_size, int k_offset, int n_offset, + const Param& _param) { + auto wptr = (StorageWeight*)(_param.packedW); + auto NPad = wptr->mNPad; + auto KPad = wptr->mKPad; + auto bptr = wptr->WPtr() + n_offset * KPad / 2 + k_offset * _GemmCore_T::NTILE / 2; + for (int i = 0; i < n_size; i += _GemmCore_T::NTILE) { + kernel::wrapper::DecompressKBlockS4FP<_T, _GemmCore_T::PACK_ROW>::template forward( + (utils::int4x2*)(bptr + i * KPad / 2), *dstptr + i * k_size, k_size / 
_GemmCore_T::PACK_ROW, + _GemmCore_T::NTILE * _GemmCore_T::PACK_ROW, _GemmCore_T::NTILE * _GemmCore_T::PACK_ROW, + _GemmCore_T::NTILE * _GemmCore_T::PACK_ROW, wptr->mSPtr + n_offset + i, + wptr->mZPtr != nullptr ? wptr->mZPtr + n_offset + i : nullptr, k_offset / _GemmCore_T::PACK_ROW, + wptr->mBlockSize / _GemmCore_T::PACK_ROW, NPad); + } + *dststep = k_size; + return JblasSuccess; + } + virtual JBLAS_CODE doCompress(const int8_t* srcptr, void* dstptr, int row, int col, int ld_src, int ld_dst) { return kernel::wrapper::CompressS8S4<_GemmCore_T::NTILE>::template forward( srcptr, reinterpret_cast(dstptr), row, col, ld_src, @@ -677,115 +642,21 @@ class StorageWeightS4ScaleFp32PerChannelN : public StorageWeightS4ScaleFp32 { break; } } - - size_t resize(int NPad, int KPad, int K, bool IsAsym = true) { - return StorageWeightS4ScaleFp32::resize(NPad, KPad, K, IsAsym); - } }; template -class WeightS4ScaleFp32PerChannelN : public WeightS8ScaleFp32PerChannelN<_GemmCore_T, ISA_T> { +class WeightS4ScaleFp32PerChannelN : public WeightS4ScaleFp32<_GemmCore_T, ISA_T, S4_T> { public: - using Parent = WeightS8ScaleFp32PerChannelN<_GemmCore_T, ISA_T>; + using Parent = WeightS4ScaleFp32<_GemmCore_T, ISA_T, S4_T>; using Param = typename Parent::Param; using StorageWeight = StorageWeightS4ScaleFp32PerChannelN; StorageWeight createStorage(const int N, const int K, bool is_asym) { int KPad = utils::padto(K, _GemmCore_T::KTILE); int NPad = utils::padto(N, _GemmCore_T::NTILE); StorageWeight tmp(_GemmCore_T::TYPE, S4_T); - tmp.resize(NPad, KPad, K, is_asym); + tmp.resize(NPad, KPad, KPad, is_asym); return tmp; } - - virtual void quantRowBlock(const float* srcptr, int8_t* dstptr, int row, int col, int ld_src, int ld_dst, - float* scales, int8_t* zero_points, int blocksize) { - kernel::wrapper::QuantizeSignIntRowBlock::forward(srcptr, dstptr, row, col, ld_src, ld_dst, scales, - zero_points, blocksize); - } - - virtual void packQWeight(const int N, const int K, const int8_t* B, const int ldb, 
const float* scales, - const int8_t* zero_points, void* ptr) override { - auto stor = reinterpret_cast(ptr); - std::memcpy(stor->mSPtr, scales, N * sizeof(scales[0])); - if (zero_points != nullptr) { - std::memcpy(stor->mZPtr, zero_points, N * sizeof(zero_points[0])); - } - utils::avector reorded(stor->mKPad * stor->mNPad); - WeightS8ScaleFp32<_GemmCore_T, ISA_T>::reorderWeight(N, K, B, ldb, reorded.data()); - compressWeight(stor->mNPad, stor->mKPad, reorded.data(), stor->mNPad, stor->WPtr()); - utils::avector deq(K * N); - WeightS8ScaleFp32<_GemmCore_T, ISA_T>::unpackWeight(N, K, stor, deq.data(), N); - WeightS8ScaleFp32<_GemmCore_T, ISA_T>::reduceWeight(N, K, K, deq.data(), ldb, stor->mRPtr, stor->mNPad); - } - - void compressWeight(const int N, const int K, const int8_t* B, const int ldb, utils::bit4x2* dstptr) { - utils::parallel::Parallel2DRowMajor _para; - utils::CpuBase cb; - _para.update(K, N, _GemmCore_T::KTILE, _GemmCore_T::NTILE, cb.mNumThreads); - omp_set_num_threads(cb.mNumThreads); -#pragma omp parallel - { - int tidx = omp_get_thread_num(); - int colidx, rowidx, rowsize, colsize; - _para.getIndex(tidx, &rowidx, &colidx, &rowsize, &colsize); - if (rowsize > 0 && colsize > 0) { - int rowremain = utils::remainsize(rowidx, K, - rowsize); // rowremain: src valid size. 
rowsize: padded size - int colremain = utils::remainsize(colidx, N, colsize); - auto ret = doCompress(B + rowidx * ldb + colidx, dstptr + rowidx * ldb / 2 + colidx / 2, rowremain, colremain, - ldb, ldb); - assert(ret == JblasSuccess); - (void)ret; - } - } - } - - virtual inline JBLAS_CODE getWeight(int8_t** dstptr, int* dststep, int k_size, int n_size, int k_offset, int n_offset, - const Param& _param) override { - auto wptr = (StorageWeight*)(_param.packedW); - auto KPad = wptr->mKPad; - auto bptr = wptr->WPtr() + n_offset * KPad / 2 + k_offset * _GemmCore_T::NTILE / 2; - for (int i = 0; i < n_size; i += _GemmCore_T::NTILE) { - kernel::wrapper::DecompressKBlockS4S8::forward( - (utils::int4x2*)(bptr + i * KPad / 2), *dstptr + i * k_size, k_size / _GemmCore_T::PACK_ROW, - _GemmCore_T::NTILE * _GemmCore_T::PACK_ROW, _GemmCore_T::NTILE * _GemmCore_T::PACK_ROW, - _GemmCore_T::NTILE * _GemmCore_T::PACK_ROW); - } - *dststep = k_size; - return JblasSuccess; - } - - virtual inline JBLAS_CODE getWeight(float** dstptr, int* dststep, int k_size, int n_size, int k_offset, int n_offset, - const Param& _param) override { - auto wptr = (StorageWeight*)(_param.packedW); - auto NPad = wptr->mNPad; - auto KPad = wptr->mKPad; - auto bptr = wptr->WPtr() + n_offset * KPad / 2 + k_offset * _GemmCore_T::NTILE / 2; - for (int i = 0; i < n_size; i += _GemmCore_T::NTILE) { - if constexpr (_GemmCore_T::PACK_ROW == 1) { - kernel::wrapper::DecompressPerNS4FP::forward( - (utils::int4x2*)(bptr + i * KPad / 2), *dstptr + i * k_size, k_size / _GemmCore_T::PACK_ROW, - _GemmCore_T::NTILE * _GemmCore_T::PACK_ROW, _GemmCore_T::NTILE * _GemmCore_T::PACK_ROW, - _GemmCore_T::NTILE * _GemmCore_T::PACK_ROW, wptr->mSPtr + n_offset + i, - wptr->mZPtr != nullptr ? 
wptr->mZPtr + n_offset + i : nullptr, k_offset, wptr->mBlockSize, NPad); - } else { - kernel::wrapper::DecompressPerNS4FPPackRow::forward( - (utils::int4x2*)(bptr + i * KPad / 2), *dstptr + i * k_size, k_size, _GemmCore_T::NTILE, _GemmCore_T::NTILE, - _GemmCore_T::NTILE, wptr->mSPtr + n_offset + i, - wptr->mZPtr != nullptr ? wptr->mZPtr + n_offset + i : nullptr, k_offset, wptr->mBlockSize, NPad, - _GemmCore_T::PACK_ROW); - } - } - *dststep = k_size; - return JblasSuccess; - } - - protected: - virtual JBLAS_CODE doCompress(const int8_t* srcptr, void* dstptr, int row, int col, int ld_src, int ld_dst) { - return kernel::wrapper::CompressS8S4<_GemmCore_T::NTILE>::template forward( - srcptr, reinterpret_cast(dstptr), row, col, ld_src, - ld_dst); // ld_dst here not stride - } }; template @@ -850,7 +721,7 @@ class WeightS4ScaleBf16 : public WeightS4ScaleFp32<_GemmCore_T, ISA_T, S4_T> { int KPad = utils::padto(K, _GemmCore_T::KTILE); int NPad = utils::padto(N, _GemmCore_T::NTILE); StorageWeight tmp(_GemmCore_T::TYPE, S4_T); - tmp.resize(NPad, KPad, blocksize <= 0 ? K : blocksize, is_asym); + tmp.resize(NPad, KPad, blocksize <= 0 ? 
KPad : blocksize, is_asym); return tmp; } @@ -875,10 +746,11 @@ class WeightS4ScaleBf16 : public WeightS4ScaleFp32<_GemmCore_T, ISA_T, S4_T> { } } } - utils::avector reorded(stor->mKPad * stor->mNPad); - WeightS8ScaleFp32<_GemmCore_T, ISA_T>::reorderWeight(N, K, B, ldb, reorded.data()); - WeightS4ScaleFp32<_GemmCore_T, ISA_T, S4_T>::compressWeight(stor->mNPad, stor->mKPad, reorded.data(), stor->mNPad, + auto reorded = utils::amalloc((size_t)stor->mKPad * stor->mNPad); + WeightS8ScaleFp32<_GemmCore_T, ISA_T>::reorderWeight(N, K, B, ldb, reorded); + WeightS4ScaleFp32<_GemmCore_T, ISA_T, S4_T>::compressWeight(stor->mNPad, stor->mKPad, reorded, stor->mNPad, stor->WPtr()); + utils::afree(reorded); } virtual inline JBLAS_CODE getScale(utils::bf16** dstptr, int* dststep, int n_size, int k_size, int n_offset, @@ -930,7 +802,7 @@ class WeightF4ScaleFp32 : public WeightS4ScaleFp32<_GemmCore_T, ISA_T, S4_CLIP> int KPad = utils::padto(K, _GemmCore_T::KTILE); int NPad = utils::padto(N, _GemmCore_T::NTILE); StorageWeight tmp(_GemmCore_T::TYPE, F4_T); - tmp.resize(NPad, KPad, blocksize <= 0 ? K : blocksize); + tmp.resize(NPad, KPad, blocksize <= 0 ? 
KPad : blocksize); return tmp; } @@ -946,10 +818,11 @@ class WeightF4ScaleFp32 : public WeightS4ScaleFp32<_GemmCore_T, ISA_T, S4_CLIP> std::memset(stor->mSPtr + i * stor->mNPad, 0, stor->mNPad * sizeof(stor->mSPtr[0])); } } - utils::avector reorded(stor->mKPad * stor->mNPad); - WeightS8ScaleFp32<_GemmCore_T, ISA_T>::reorderWeight(N, K, B, ldb, reorded.data()); - WeightS4ScaleFp32<_GemmCore_T, ISA_T, S4_CLIP>::compressWeight(stor->mNPad, stor->mKPad, reorded.data(), - stor->mNPad, stor->WPtr()); + auto reorded = utils::amalloc((size_t)stor->mKPad * stor->mNPad); + WeightS8ScaleFp32<_GemmCore_T, ISA_T>::reorderWeight(N, K, B, ldb, reorded); + WeightS4ScaleFp32<_GemmCore_T, ISA_T, S4_CLIP>::compressWeight(stor->mNPad, stor->mKPad, reorded, stor->mNPad, + stor->WPtr()); + utils::afree(reorded); } virtual inline JBLAS_CODE getWeight(float** dstptr, int* dststep, int k_size, int n_size, int k_offset, int n_offset, @@ -959,16 +832,11 @@ class WeightF4ScaleFp32 : public WeightS4ScaleFp32<_GemmCore_T, ISA_T, S4_CLIP> auto KPad = wptr->mKPad; auto bptr = wptr->WPtr() + n_offset * KPad / 2 + k_offset * _GemmCore_T::NTILE / 2; for (int i = 0; i < n_size; i += _GemmCore_T::NTILE) { - if constexpr (_GemmCore_T::PACK_ROW == 1) { - kernel::wrapper::DecompressKBlockF4Fp::forward( - reinterpret_cast(bptr + i * KPad / 2), *dstptr + i * k_size, k_size / _GemmCore_T::PACK_ROW, - _GemmCore_T::NTILE * _GemmCore_T::PACK_ROW, _GemmCore_T::NTILE * _GemmCore_T::PACK_ROW, - _GemmCore_T::NTILE * _GemmCore_T::PACK_ROW, wptr->mSPtr + n_offset + i, k_offset, wptr->mBlockSize, NPad); - } else { - kernel::wrapper::DecompressKBlockF4FPPackRow::forward( - (utils::f4x2*)(bptr + i * KPad / 2), *dstptr + i * k_size, k_size, _GemmCore_T::NTILE, _GemmCore_T::NTILE, - _GemmCore_T::NTILE, wptr->mSPtr + n_offset + i, k_offset, wptr->mBlockSize, NPad, _GemmCore_T::PACK_ROW); - } + kernel::wrapper::DecompressKBlockF4Fp::template forward( + reinterpret_cast(bptr + i * KPad / 2), *dstptr + i * k_size, k_size / 
_GemmCore_T::PACK_ROW, + _GemmCore_T::NTILE * _GemmCore_T::PACK_ROW, _GemmCore_T::NTILE * _GemmCore_T::PACK_ROW, + _GemmCore_T::NTILE * _GemmCore_T::PACK_ROW, wptr->mSPtr + n_offset + i, k_offset / _GemmCore_T::PACK_ROW, + wptr->mBlockSize / _GemmCore_T::PACK_ROW, NPad); } *dststep = k_size; return JblasSuccess; @@ -981,7 +849,7 @@ class WeightF4ScaleFp32 : public WeightS4ScaleFp32<_GemmCore_T, ISA_T, S4_CLIP> auto KPad = wptr->mKPad; auto bptr = wptr->WPtr() + n_offset * KPad / 2 + k_offset * _GemmCore_T::NTILE / 2; for (int i = 0; i < n_size; i += _GemmCore_T::NTILE) { - kernel::wrapper::DecompressKBlockF4Fp::forward( + kernel::wrapper::DecompressKBlockF4Fp::template forward( reinterpret_cast(bptr + i * KPad / 2), *dstptr + i * k_size, k_size / _GemmCore_T::PACK_ROW, _GemmCore_T::NTILE * _GemmCore_T::PACK_ROW, _GemmCore_T::NTILE * _GemmCore_T::PACK_ROW, _GemmCore_T::NTILE * _GemmCore_T::PACK_ROW, wptr->mSPtr + n_offset + i, k_offset / _GemmCore_T::PACK_ROW, @@ -1327,8 +1195,8 @@ class GemmInterfaceKblockParallelAB { template JBLAS_CODE compute(const Arguments& _param) { auto bptr = (prologue::weight_comp::gemm_kblcok::WeightBase*)(_param.paramB.packedW); - auto paraA = getActivationPtr()->createParallel(_param.M, _param.K, _param.KBlock); - auto paraB = getWeightPtr()->createParallel(_param.K, _param.N, _param.KBlock); + auto paraA = getActivationPtr()->createParallel(_param.M, _param.K, bptr->mBlockSize); + auto paraB = getWeightPtr()->createParallel(_param.K, _param.N, bptr->mBlockSize); auto para = Parallel(); auto cb = utils::CpuBase(); if (para.update(_param.M, _param.N, _param.K, bptr->mBlockSize, cb.mNumThreads)) { diff --git a/intel_extension_for_transformers/llm/library/jblas/jblas/jit_blas_wrapper.h b/intel_extension_for_transformers/llm/library/jblas/jblas/jit_blas_wrapper.h index 59b4f070b9d..8f340480ee8 100644 --- a/intel_extension_for_transformers/llm/library/jblas/jblas/jit_blas_wrapper.h +++ 
b/intel_extension_for_transformers/llm/library/jblas/jblas/jit_blas_wrapper.h @@ -215,23 +215,23 @@ JBLAS_ISA constexpr DefaultISA = JblasAVX2; using GemmKernel = jblas::wrapper::gemm_pack_weight::GemmInterfacePackWeight< jblas::wrapper::gemm_pack_weight::GemmLauncherPackWeight< // DefaultISA, // - jblas::gemm::GemmCore_Row_NN_4x24_AVX2, // + jblas::gemm::GemmCore_Row_NN_4x24_AVX2, // jblas::prologue::gemm::ActivationBase, // jblas::prologue::gemm::WeightPack, // jblas::epilogue::gemm::AccumulatorWriteBackFp32>, DefaultParallel>; -} +} // namespace avx2 namespace avx_vnni { JBLAS_ISA constexpr DefaultISA = JblasAVX_VNNI; using GemmKernel48 = jblas::wrapper::gemm_pack_weight::GemmInterfacePackWeight< jblas::wrapper::gemm_pack_weight::GemmLauncherPackWeight< // DefaultISA, // - jblas::gemm::GemmCore_Row_NN_2x48_AVX_VNNI, // + jblas::gemm::GemmCore_Row_NN_2x48_AVX_VNNI, // jblas::prologue::gemm::ActivationBase, // jblas::prologue::gemm::WeightPack, // jblas::epilogue::gemm::AlphaBetaProcessS32U8>, DefaultParallel>; -} +} // namespace avx_vnni namespace avx512f { JBLAS_ISA constexpr DefaultISA = JblasAVX512F; using GemmKernel = jblas::wrapper::gemm_pack_weight::GemmInterfacePackWeight< diff --git a/intel_extension_for_transformers/llm/library/jblas/jblas/kernel_avx2.h b/intel_extension_for_transformers/llm/library/jblas/jblas/kernel_avx2.h index fc303e6ae8f..e3cc5b0ac3f 100644 --- a/intel_extension_for_transformers/llm/library/jblas/jblas/kernel_avx2.h +++ b/intel_extension_for_transformers/llm/library/jblas/jblas/kernel_avx2.h @@ -200,11 +200,12 @@ static inline void dequant_f4_N(_DST_T* dstptr, int8_t* srcptr, __m256* vscales) } } -template -static inline JBLAS_CODE decompress_kblock_bit4_fp32(utils::bit4x2* srcptr, float* dstptr, int row, int col, int ld_src, - int ld_dst, _ST* scales, int8_t* zero_points, int k_offset, - int kblock, int NPad, void (*dequantize)(float*, int8_t*, __m256*), - void (*pad_bit4)(int8_t*, int8_t*)) { +template +static inline JBLAS_CODE 
decompress_kblock_bit4_packrow1(utils::bit4x2* srcptr, _DST_T* dstptr, int row, int col, + int ld_src, int ld_dst, _ST* scales, int8_t* zero_points, + int k_offset, int kblock, int NPad, + void (*dequantize)(_DST_T*, int8_t*, __m256*), + void (*pad_bit4)(int8_t*, int8_t*)) { uint32_t mask = 0xf0f0f0f0; auto vmask = _mm256_set1_epi32(*(int*)&mask); if (col == 48) { @@ -262,24 +263,26 @@ static inline JBLAS_CODE decompress_kblock_bit4_fp32(utils::bit4x2* srcptr, floa return JblasNotSupport; } -template -static inline JBLAS_CODE decompress_kblock_bit4_bf16(utils::bit4x2* srcptr, utils::bf16* dstptr, int row, int col, - int ld_src, int ld_dst, _ST* scales, int8_t* zero_points, - int k_offset, int kblock, int NPad, - void (*dequantize)(utils::bf16*, int8_t*, __m256*), - void (*pad_bit4)(int8_t*, int8_t*)) { +template +static inline JBLAS_CODE decompress_kblock_bit4_packrow2(utils::bit4x2* srcptr, _DST_T* dstptr, int row, int col, + int ld_src, int ld_dst, _ST* scales, int8_t* zero_points, + int k_offset, int kblock, int NPad, + void (*dequantize)(_DST_T*, int8_t*, __m256*), + void (*pad_bit4)(int8_t*, int8_t*)) { return JblasNotSupport; } -template +template static inline JBLAS_CODE decompress_kblock_f4_fp(utils::f4x2* srcptr, _DST_T* dstptr, int row, int col, int ld_src, int ld_dst, _ST* scales, int k_offset, int kblock, int NPad) { - if constexpr (std::is_same<_DST_T, float>::value) { - return decompress_kblock_bit4_fp32<_ST>(srcptr, (float*)dstptr, row, col, ld_src, ld_dst, scales, nullptr, k_offset, - kblock, NPad, &dequant_f4_N<48, float, F4_T>, fp4_pad_4bit); - } else if constexpr (std::is_same<_DST_T, utils::bf16>::value) { - return decompress_kblock_bit4_bf16<_ST>(srcptr, (utils::bf16*)dstptr, row, col, ld_src, ld_dst, scales, nullptr, - k_offset, kblock, NPad, &dequant_f4_N<64, utils::bf16, F4_T>, fp4_pad_4bit); + if constexpr (_PACK_ROW == 1) { + return decompress_kblock_bit4_packrow1<_ST, _DST_T>(srcptr, dstptr, row, col, ld_src, ld_dst, scales, nullptr, + 
k_offset, kblock, NPad, &dequant_f4_N<48, _DST_T, _F4_T>, + fp4_pad_4bit); + } else if constexpr (_PACK_ROW == 2) { + return decompress_kblock_bit4_packrow2<_ST, _DST_T>(srcptr, dstptr, row, col, ld_src, ld_dst, scales, nullptr, + k_offset, kblock, NPad, &dequant_f4_N<64, _DST_T, _F4_T>, + fp4_pad_4bit); } return JblasNotSupport; } @@ -361,7 +364,7 @@ static inline JBLAS_CODE quantize_fp_u8_colblock(int row, int col, const SRC_T* for (; ij < blocksize; ij++) { auto srcval = (float)srcptr[(j + ij) + i * ld_src]; srcval = srcval * rscale; - auto srcint = int(srcval + 0.5f) + zp; + auto srcint = int(roundf(srcval)) + zp; srcint = std::min(srcint, 0xff); srcint = std::max(srcint, 0); dstptr[(j + ij) + i * ld_dst] = static_cast(srcint); diff --git a/intel_extension_for_transformers/llm/library/jblas/jblas/kernel_avx512f.h b/intel_extension_for_transformers/llm/library/jblas/jblas/kernel_avx512f.h index ebe9ea03258..0b1a18b4d10 100644 --- a/intel_extension_for_transformers/llm/library/jblas/jblas/kernel_avx512f.h +++ b/intel_extension_for_transformers/llm/library/jblas/jblas/kernel_avx512f.h @@ -117,7 +117,7 @@ static inline void dequant_f4_N(_DST_T* dstptr, int8_t* srcptr, __m512* vscales, #if CompileBF16() auto bf16_v = (__m256i)_mm512_cvtneps_pbh(fzmm); #else - auto bf16_v = _mm512_cvtepi32_epi16(_mm512_bsrli_epi128(_mm512_castps_si512(fzmm), 2));//TODO cvt with LSB + auto bf16_v = _mm512_cvtepi32_epi16(_mm512_bsrli_epi128(_mm512_castps_si512(fzmm), 2)); // TODO cvt with LSB #endif _mm256_storeu_si256((__m256i*)(dstptr + iv * 16), bf16_v); } else { @@ -160,22 +160,24 @@ static inline void vec_broadcast_epi32_2_4(__m512i* dst4regs, __m512i* src2regs) vec_broadcast_epi32_1_2(dst4regs + 2, src2regs + 1); } -template -static inline JBLAS_CODE decompress_kblock_bit4_fp32(utils::bit4x2* srcptr, float* dstptr, int row, int col, int ld_src, - int ld_dst, _ST* scales, int8_t* zero_points, int k_offset, - int kblock, int NPad, - void (*dequantize)(float*, int8_t*, __m512*, 
__m128i*), - void (*pad_bit4)(int8_t*, int8_t*, __m512i, int)) { +template +static inline JBLAS_CODE decompress_kblock_bit4_packrow1(utils::bit4x2* srcptr, _DT* dstptr, int row, int col, + int ld_src, int ld_dst, _ST* scales, int8_t* zero_points, + int k_offset, int kblock, int NPad, + void (*dequantize)(_DT*, int8_t*, __m512*, __m128i*), + void (*pad_bit4)(int8_t*, int8_t*, __m512i, int)) { uint32_t mask = 0xf0f0f0f0; auto zmm_mask = _mm512_set1_epi32(*(int*)&mask); if (col == 48) { + constexpr int ColTile = 48; + constexpr int NRegs = ColTile / 16; constexpr int LoadMask64 = (1 << (64 / 8)) - 1; constexpr int LoadMask48 = (1 << (48 / 8)) - 1; - __m512 vscales[3]; - __m128i vzps[3]; + __m512 vscales[NRegs]; + __m128i vzps[NRegs]; int constexpr UnrollRow = 4; - int constexpr Loop64 = 48 * UnrollRow / 64; - int8_t tmpbuf[48 * UnrollRow]; + int constexpr Loop64 = ColTile * UnrollRow / 64; + int8_t tmpbuf[ColTile * UnrollRow]; int row0 = kblock - k_offset % kblock; row0 = row0 == kblock ? 0 : row0; row0 = row0 > row ? 
row : row0; @@ -194,9 +196,9 @@ static inline JBLAS_CODE decompress_kblock_bit4_fp32(utils::bit4x2* srcptr, floa } for (int iterr = 0; iterr < UnrollRow; iterr++) { if constexpr (_IS_SYM) { - dequantize(dstptr + (irow + iterr) * ld_dst, tmpbuf + iterr * 48, vscales, nullptr); + dequantize(dstptr + (irow + iterr) * ld_dst, tmpbuf + iterr * ColTile, vscales, nullptr); } else { - dequantize(dstptr + (irow + iterr) * ld_dst, tmpbuf + iterr * 48, vscales, vzps); + dequantize(dstptr + (irow + iterr) * ld_dst, tmpbuf + iterr * ColTile, vscales, vzps); } } } @@ -228,9 +230,9 @@ static inline JBLAS_CODE decompress_kblock_bit4_fp32(utils::bit4x2* srcptr, floa } for (int iterr = 0; iterr < UnrollRow; iterr++) { if constexpr (_IS_SYM) { - dequantize(dstptr + (irow + irr + iterr) * ld_dst, tmpbuf + iterr * 48, vscales, nullptr); + dequantize(dstptr + (irow + irr + iterr) * ld_dst, tmpbuf + iterr * ColTile, vscales, nullptr); } else { - dequantize(dstptr + (irow + irr + iterr) * ld_dst, tmpbuf + iterr * 48, vscales, vzps); + dequantize(dstptr + (irow + irr + iterr) * ld_dst, tmpbuf + iterr * ColTile, vscales, vzps); } } } @@ -255,22 +257,24 @@ static inline JBLAS_CODE decompress_kblock_bit4_fp32(utils::bit4x2* srcptr, floa return JblasNotSupport; } -template -static inline JBLAS_CODE decompress_kblock_bit4_bf16(utils::bit4x2* srcptr, utils::bf16* dstptr, int row, int col, - int ld_src, int ld_dst, _ST* scales, int8_t* zero_points, - int k_offset, int kblock, int NPad, - void (*dequantize)(utils::bf16*, int8_t*, __m512*, __m128i*), - void (*pad_bit4)(int8_t*, int8_t*, __m512i, int)) { +template +static inline JBLAS_CODE decompress_kblock_bit4_packrow2(utils::bit4x2* srcptr, _DT* dstptr, int row, int col, + int ld_src, int ld_dst, _ST* scales, int8_t* zero_points, + int k_offset, int kblock, int NPad, + void (*dequantize)(_DT*, int8_t*, __m512*, __m128i*), + void (*pad_bit4)(int8_t*, int8_t*, __m512i, int)) { uint32_t mask = 0xf0f0f0f0; auto zmm_mask = 
_mm512_set1_epi32(*(int*)&mask); auto broadcast_idx = _mm512_setr_epi64(0, 4, 1, 5, 2, 6, 3, 7); auto broadcast_idx_128 = _mm_setr_epi16(0, 1, 2, 3, 4, 5, 6, 7); if (col % 64 == 0) { + constexpr int ColTile = 64; + constexpr int NRegs = ColTile / 16; constexpr int LoadMask64 = (1 << (64 / 8)) - 1; - for (int icol = 0; icol < col; icol += 64) { - __m512 vscales[4]; - __m128i vzps[4]; - int8_t tmpbuf[64]; + for (int icol = 0; icol < col; icol += ColTile) { + __m512 vscales[NRegs]; + __m128i vzps[NRegs]; + int8_t tmpbuf[ColTile]; int row0 = kblock - k_offset % kblock; row0 = row0 == kblock ? 0 : row0; row0 = row0 > row ? row : row0; @@ -344,44 +348,45 @@ static inline JBLAS_CODE decompress_kblock_bit4_bf16(utils::bit4x2* srcptr, util return JblasNotSupport; } -template +template static inline JBLAS_CODE decompress_kblock_s4_fp(utils::int4x2* srcptr, _DST_T* dstptr, int row, int col, int ld_src, int ld_dst, _ST* scales, int8_t* zero_points, int k_offset, int kblock, int NPad) { - if constexpr (std::is_same<_DST_T, float>::value) { + if constexpr (_PACK_ROW == 1) { if (zero_points == nullptr) { - return decompress_kblock_bit4_fp32<_ST, true>(srcptr, (float*)dstptr, row, col, ld_src, ld_dst, scales, - zero_points, k_offset, kblock, NPad, &dequant_s8_N<48, float, true>, - &convert_s4_s8); + return decompress_kblock_bit4_packrow1<_ST, _DST_T, true>(srcptr, dstptr, row, col, ld_src, ld_dst, scales, + zero_points, k_offset, kblock, NPad, + &dequant_s8_N<48, _DST_T, true>, &convert_s4_s8); } else { - return decompress_kblock_bit4_fp32<_ST, false>(srcptr, (float*)dstptr, row, col, ld_src, ld_dst, scales, - zero_points, k_offset, kblock, NPad, - &dequant_s8_N<48, float, false>, &convert_s4_s8); + return decompress_kblock_bit4_packrow1<_ST, _DST_T, false>( + srcptr, dstptr, row, col, ld_src, ld_dst, scales, zero_points, k_offset, kblock, NPad, + &dequant_s8_N<48, _DST_T, false>, &convert_s4_s8); } - } else if constexpr (std::is_same<_DST_T, utils::bf16>::value) { + } else if 
constexpr (_PACK_ROW == 2) { if (zero_points == nullptr) { - return decompress_kblock_bit4_bf16<_ST, true>(srcptr, dstptr, row, col, ld_src, ld_dst, scales, zero_points, - k_offset, kblock, NPad, &dequant_s8_N<64, utils::bf16, true>, - &convert_s4_s8); + return decompress_kblock_bit4_packrow2<_ST, _DST_T, true>(srcptr, dstptr, row, col, ld_src, ld_dst, scales, + zero_points, k_offset, kblock, NPad, + &dequant_s8_N<64, _DST_T, true>, &convert_s4_s8); } else { - return decompress_kblock_bit4_bf16<_ST, false>(srcptr, dstptr, row, col, ld_src, ld_dst, scales, zero_points, - k_offset, kblock, NPad, &dequant_s8_N<64, utils::bf16, false>, - &convert_s4_s8); + return decompress_kblock_bit4_packrow2<_ST, _DST_T, false>( + srcptr, dstptr, row, col, ld_src, ld_dst, scales, zero_points, k_offset, kblock, NPad, + &dequant_s8_N<64, _DST_T, false>, &convert_s4_s8); } } return JblasNotSupport; } -template +template static inline JBLAS_CODE decompress_kblock_f4_fp(utils::f4x2* srcptr, _DST_T* dstptr, int row, int col, int ld_src, int ld_dst, _ST* scales, int k_offset, int kblock, int NPad) { - if constexpr (std::is_same<_DST_T, float>::value) { - return decompress_kblock_bit4_fp32<_ST, true>(srcptr, (float*)dstptr, row, col, ld_src, ld_dst, scales, nullptr, - k_offset, kblock, NPad, &dequant_f4_N<48, float, F4_T>, pad_fp4); - } else if constexpr (std::is_same<_DST_T, utils::bf16>::value) { - return decompress_kblock_bit4_bf16<_ST, true>(srcptr, (utils::bf16*)dstptr, row, col, ld_src, ld_dst, scales, - nullptr, k_offset, kblock, NPad, &dequant_f4_N<64, utils::bf16, F4_T>, - pad_fp4); + if constexpr (_PACK_ROW == 1) { + return decompress_kblock_bit4_packrow1<_ST, _DST_T, true>(srcptr, dstptr, row, col, ld_src, ld_dst, scales, nullptr, + k_offset, kblock, NPad, &dequant_f4_N<48, _DST_T, _F4_T>, + pad_fp4); + } else if constexpr (_PACK_ROW == 2) { + return decompress_kblock_bit4_packrow2<_ST, _DST_T, true>(srcptr, dstptr, row, col, ld_src, ld_dst, scales, nullptr, + k_offset, kblock, 
NPad, &dequant_f4_N<64, _DST_T, _F4_T>, + pad_fp4); } return JblasNotSupport; } @@ -614,7 +619,7 @@ static inline JBLAS_CODE quantize_fp_u8_colblock(int row, int col, const SRC_T* for (; ij < blocksize; ij++) { auto srcval = (float)srcptr[(j + ij) + i * ld_src]; srcval = srcval * rscale; - auto srcint = int(srcval + 0.5f) + zp; + auto srcint = int(roundf(srcval)) + zp; srcint = std::min(srcint, 0xff); srcint = std::max(srcint, 0); dstptr[(j + ij) + i * ld_dst] = static_cast(srcint); @@ -701,7 +706,7 @@ static inline JBLAS_CODE quantize_fp_s8_colblock(int row, int col, const SRC_T* for (; ij < blocksize; ij++) { auto srcval = (float)srcptr[(j + ij) + i * ld_src]; srcval = srcval * rscale; - auto srcint = int(srcval + 0.5f); + auto srcint = int(roundf(srcval)); srcint = std::min(srcint, 127); srcint = std::max(srcint, -127); dstptr[(j + ij) + i * ld_dst] = static_cast(srcint); @@ -881,8 +886,8 @@ static inline JBLAS_CODE broadcast_u8(int num, const uint8_t& srcval, uint8_t* d return JblasSuccess; } -static inline JBLAS_CODE remove_zeropoint_bias(float* accptr, int ldacc, int row, int col, uint8_t* zps, float* scales, - int lds, const float* reduce) { +static inline JBLAS_CODE remove_act_zeropoint_bias(float* accptr, int ldacc, int row, int col, uint8_t* zps, + float* scales, int lds, const float* reduce) { int constexpr VLen = 16; auto col16 = utils::padto_le(col, VLen); for (int i = 0; i < row; i++) { @@ -896,7 +901,7 @@ static inline JBLAS_CODE remove_zeropoint_bias(float* accptr, int ldacc, int row _mm512_storeu_ps(&accptr[i * ldacc + j], vacc); } if (j < col) { - for (; j < col16; j++) { + for (; j < col; j++) { accptr[i * ldacc + j] -= zpf * reduce[j]; } } @@ -904,6 +909,64 @@ static inline JBLAS_CODE remove_zeropoint_bias(float* accptr, int ldacc, int row return JblasSuccess; } +static inline JBLAS_CODE remove_wei_zeropoint_bias(float* accptr, int ldacc, int row, int col, int8_t* zps, + float* scales, int lds, const float* reduce) { + int constexpr VLen = 16; + 
auto col16 = utils::padto_le(col, VLen); + for (int i = 0; i < row; i++) { + auto vreduce = _mm512_set1_ps(-reduce[i * lds]); + int j = 0; + for (; j < col16; j += VLen) { + auto vzp_s32 = _mm512_cvtepi8_epi32(_mm_loadu_si128((__m128i*)(zps + j))); + auto vzp_f32 = _mm512_cvtepi32_ps(vzp_s32); + auto vzp = _mm512_mul_ps(vzp_f32, _mm512_loadu_ps(scales + j)); + auto vacc = _mm512_loadu_ps(&accptr[i * ldacc + j]); + vacc = _mm512_fmadd_ps(vzp, vreduce, vacc); + _mm512_storeu_ps(&accptr[i * ldacc + j], vacc); + } + if (j < col) { + for (; j < col; j++) { + accptr[i * ldacc + j] -= float(zps[j]) * scales[j] * reduce[i * lds]; + } + } + } + return JblasSuccess; +} + +static inline JBLAS_CODE remove_zeropoint_bias(float* accptr, int ldacc, int row, int col, uint8_t* zpa, int8_t* zpb, + float* scalea, float* scaleb, int lds, int k, const float* reducea, + const float* reduceb) { + int constexpr VLen = 16; + auto col16 = utils::padto_le(col, VLen); + auto vk = _mm512_set1_ps((float)(k)); + for (int i = 0; i < row; i++) { + auto vreducea = _mm512_set1_ps(-reducea[i * lds]); + auto zpaf = float(zpa[i * lds]) * scalea[i * lds]; + auto vzpa = _mm512_set1_ps(-zpaf); + int j = 0; + for (; j < col16; j += VLen) { + auto vzp_s32 = _mm512_cvtepi8_epi32(_mm_loadu_si128((__m128i*)(zpb + j))); + auto vzp_f32 = _mm512_cvtepi32_ps(vzp_s32); + auto vzpb = _mm512_mul_ps(vzp_f32, _mm512_loadu_ps(scaleb + j)); + auto vreduceb = _mm512_loadu_ps(reduceb + j); + auto vacc = _mm512_loadu_ps(&accptr[i * ldacc + j]); + vacc = _mm512_fmadd_ps(vzpa, vreduceb, vacc); + vacc = _mm512_fmadd_ps(vzpb, vreducea, vacc); + vzpb = _mm512_mul_ps(vzpb, vk); + vacc = _mm512_fmadd_ps(vzpa, vzpb, vacc); + _mm512_storeu_ps(&accptr[i * ldacc + j], vacc); + } + if (j < col) { + for (; j < col; j++) { + accptr[i * ldacc + j] -= float(zpb[j]) * scaleb[j] * reducea[i * lds]; + accptr[i * ldacc + j] -= zpaf * reduceb[j]; + accptr[i * ldacc + j] -= zpaf * float(zpb[j]) * scaleb[j] * k; + } + } + } + return
JblasSuccess; +} + static inline JBLAS_CODE fp32_cvt_bf16_2D_write_back(const void* raw_srcptr, void* raw_dstptr, int row, int col, int srcstride, int dststride, bool zeropadding) { char* srcptr = (char*)raw_srcptr; @@ -931,7 +994,8 @@ static inline JBLAS_CODE fp32_cvt_bf16_2D_write_back(const void* raw_srcptr, voi auto round_bias = _mm512_maskz_loadu_epi32(tail_mask, src + sizeof(float) * simd_proc_elt * j); round_bias = _mm512_and_epi32(bf16_and_helper, _mm512_bsrli_epi128(round_bias, 2)); round_bias = _mm512_add_epi32(round_bias, bf16_add_helper); - auto round_fp32_v = _mm512_add_epi32(round_bias, _mm512_maskz_loadu_epi32(tail_mask, src + sizeof(float) * simd_proc_elt * j)); + auto round_fp32_v = + _mm512_add_epi32(round_bias, _mm512_maskz_loadu_epi32(tail_mask, src + sizeof(float) * simd_proc_elt * j)); auto pack_bf16_tail = _mm512_cvtepi32_epi16(_mm512_srli_epi32(round_fp32_v, 16)); _mm256_mask_storeu_epi16((__m256i*)(dst + (j * simd_proc_elt) * sizeof(jblas::utils::bf16)), tail_mask, pack_bf16_tail); diff --git a/intel_extension_for_transformers/llm/library/jblas/jblas/kernel_jit.h b/intel_extension_for_transformers/llm/library/jblas/jblas/kernel_jit.h index a06a9f3db71..dcab796fc59 100644 --- a/intel_extension_for_transformers/llm/library/jblas/jblas/kernel_jit.h +++ b/intel_extension_for_transformers/llm/library/jblas/jblas/kernel_jit.h @@ -224,7 +224,7 @@ class DequanKBlockS8F32 { if (row2 > 0) { DequanS8F32::forward_avx512f(srcptr, dstptr, row2, col, ld_src, ld_dst, sptr, zptr); } - return JblasNotSupport; + return JblasSuccess; } }; diff --git a/intel_extension_for_transformers/llm/library/jblas/jblas/kernel_ref.h b/intel_extension_for_transformers/llm/library/jblas/jblas/kernel_ref.h index e4a26bb7540..b2ee24e403b 100644 --- a/intel_extension_for_transformers/llm/library/jblas/jblas/kernel_ref.h +++ b/intel_extension_for_transformers/llm/library/jblas/jblas/kernel_ref.h @@ -203,53 +203,22 @@ inline JBLAS_CODE decompress_s4_s8(utils::int4x2* srcptr, 
int8_t* dstptr, int ro return JblasSuccess; } -inline JBLAS_CODE decompress_kblock_s8_f32(int8_t* srcptr, float* dstptr, int row, int col, int ld_src, int ld_dst, - float* scales, int8_t* zero_points, int k_offset, int kblock, int NPad) { +template +inline JBLAS_CODE decompress_kblock_s8_f32(int8_t* srcptr, _DST_T* dstptr, int row, int col, int ld_src, int ld_dst, + _S_T* scales, int8_t* zero_points, int k_offset, int kblock, int NPad) { for (int i = 0; i < row; i++) { int kpos = (k_offset + i) / kblock; auto sptr = scales + kpos * NPad; for (int j = 0; j < col; j += 1) { float tmp = (float)(srcptr[i * ld_src + j]); if (zero_points != nullptr) tmp -= (float)(zero_points[kpos * NPad + j]); - dstptr[i * ld_dst + j] = tmp * sptr[j + 0]; + dstptr[i * ld_dst + j] = static_cast<_DST_T>(tmp * sptr[j / _PACK_ROW]); } } return JblasSuccess; } -inline JBLAS_CODE decompress_kblock_s8_f32_packrow(int8_t* srcptr, float* dstptr, int row, int col, int ld_src, - int ld_dst, float* scales, int8_t* zero_points, int k_offset, - int kblock, int NPad, int packrow) { - for (int i = 0; i < row; i += packrow) { - if (packrow == 1) { - int kpos = (k_offset + i) / kblock; - auto sptr = scales + kpos * NPad; - for (int j = 0; j < col; j += 1) { - float tmp = (float)(srcptr[i * ld_src + j]); - if (zero_points != nullptr) tmp -= (float)(zero_points[kpos * NPad + j]); - dstptr[i * ld_dst + j + 0] = tmp * sptr[j]; - } - } else { - for (int j = 0; j < col; j++) { - for (int k = 0; k < packrow; k += 2) { - int kpos = (k_offset + i + k) / kblock; - auto sptr = scales + kpos * NPad + j; - auto tmp = (srcptr + i * ld_src + j * packrow + k); - if (zero_points != nullptr) { - dstptr[i * ld_dst + j * packrow + k + 0] = ((float)tmp[0] - (float)zero_points[kpos * NPad + j]) * sptr[0]; - dstptr[i * ld_dst + j * packrow + k + 1] = ((float)tmp[1] - (float)zero_points[kpos * NPad + j]) * sptr[0]; - } else { - dstptr[i * ld_dst + j * packrow + k + 0] = (float)tmp[0] * sptr[0]; - dstptr[i * ld_dst + j * packrow 
+ k + 1] = (float)tmp[1] * sptr[0]; - } - } - } - } - } - return JblasSuccess; -} - -template +template inline JBLAS_CODE decompress_kblock_s4_fp(utils::int4x2* srcptr, _DST_T* dstptr, int row, int col, int ld_src, int ld_dst, _S_T* scales, int8_t* zero_points, int k_offset, int kblock, int NPad) { @@ -260,20 +229,10 @@ inline JBLAS_CODE decompress_kblock_s4_fp(utils::int4x2* srcptr, _DST_T* dstptr, auto tmp = srcptr[i * ld_src / 2 + j / 2]; float scale0, scale1, dst0, dst1; int s0_idx, s1_idx; - if constexpr (std::is_same<_DST_T, utils::bf16>::value) { - s0_idx = j / 2; - s1_idx = j / 2; - } else { - s0_idx = j; - s1_idx = j + 1; - } - if constexpr (std::is_same<_S_T, utils::bf16>::value) { - scale0 = sptr[s0_idx].tofloat(); - scale1 = sptr[s1_idx].tofloat(); - } else { - scale0 = sptr[s0_idx]; - scale1 = sptr[s1_idx]; - } + s0_idx = j / _PACK_ROW; + s1_idx = (j + 1) / _PACK_ROW; + scale0 = float(sptr[s0_idx]); + scale1 = float(sptr[s1_idx]); if (zero_points != nullptr) { dst0 = (float(get_s8(tmp.x)) - float((zero_points + kpos * NPad)[s0_idx])) * scale0; dst1 = (float(get_s8(tmp.y)) - float((zero_points + kpos * NPad)[s1_idx])) * scale1; @@ -281,142 +240,8 @@ inline JBLAS_CODE decompress_kblock_s4_fp(utils::int4x2* srcptr, _DST_T* dstptr, dst0 = float(get_s8(tmp.x)) * scale0; dst1 = float(get_s8(tmp.y)) * scale1; } - if constexpr (std::is_same<_DST_T, utils::bf16>::value) { - utils::bf16 bf16_ret0, bf16_ret1; - bf16_ret0.fromfloat(dst0); - bf16_ret1.fromfloat(dst1); - dstptr[i * ld_dst + j + 0] = bf16_ret0; - dstptr[i * ld_dst + j + 1] = bf16_ret1; - } else { - dstptr[i * ld_dst + j + 0] = dst0; - dstptr[i * ld_dst + j + 1] = dst1; - } - } - } - return JblasSuccess; -} - -template -inline JBLAS_CODE decompress_pern_s4_fp(utils::int4x2* srcptr, _DST_T* dstptr, int row, int col, int ld_src, int ld_dst, - _S_T* scales, int8_t* zero_points, int k_offset, int kblock, int NPad) { - auto sptr = scales; - for (int i = 0; i < row; i++) { - for (int j = 0; j < col; j += 2) 
{ - auto tmp = srcptr[i * ld_src / 2 + j / 2]; - float scale0, scale1, dst0, dst1; - int s0_idx, s1_idx; - if constexpr (std::is_same<_DST_T, utils::bf16>::value) { - s0_idx = j / 2; - s1_idx = j / 2; - } else { - s0_idx = j; - s1_idx = j + 1; - } - if constexpr (std::is_same<_S_T, utils::bf16>::value) { - scale0 = sptr[s0_idx].tofloat(); - scale1 = sptr[s1_idx].tofloat(); - } else { - scale0 = sptr[s0_idx]; - scale1 = sptr[s1_idx]; - } - if (zero_points != nullptr) { - dst0 = (float(get_s8(tmp.x)) - float((zero_points)[j + 0])) * scale0; - dst1 = (float(get_s8(tmp.y)) - float((zero_points)[j + 1])) * scale1; - } else { - dst0 = float(get_s8(tmp.x)) * scale0; - dst1 = float(get_s8(tmp.y)) * scale1; - } - if constexpr (std::is_same<_DST_T, utils::bf16>::value) { - utils::bf16 bf16_ret0, bf16_ret1; - bf16_ret0.fromfloat(dst0); - bf16_ret1.fromfloat(dst1); - dstptr[i * ld_dst + j + 0] = bf16_ret0; - dstptr[i * ld_dst + j + 1] = bf16_ret1; - } else { - dstptr[i * ld_dst + j + 0] = dst0; - dstptr[i * ld_dst + j + 1] = dst1; - } - } - } - return JblasSuccess; -} -template -inline JBLAS_CODE decompress_kblock_s4_fp_packrow(utils::int4x2* srcptr, float* dstptr, int row, int col, int ld_src, - int ld_dst, float* scales, int8_t* zero_points, int k_offset, - int kblock, int NPad, int packrow) { - for (int i = 0; i < row; i += packrow) { - if (packrow == 1) { - int kpos = (k_offset + i) / kblock; - auto sptr = scales + kpos * NPad; - for (int j = 0; j < col; j += 2) { - auto tmp = srcptr[i * ld_src / 2 + j / 2]; - if (zero_points == nullptr) { - dstptr[i * ld_dst + j + 0] = float(get_s8(tmp.x)) * sptr[j + 0]; - dstptr[i * ld_dst + j + 1] = float(get_s8(tmp.y)) * sptr[j + 1]; - } else { - dstptr[i * ld_dst + j + 0] = - (float(get_s8(tmp.x)) - float((zero_points + kpos * NPad)[j + 0])) * sptr[j + 0]; - dstptr[i * ld_dst + j + 1] = - (float(get_s8(tmp.y)) - float((zero_points + kpos * NPad)[j + 1])) * sptr[j + 1]; - } - } - } else { - for (int j = 0; j < col; j++) { - for (int k 
= 0; k < packrow; k += 2) { - int kpos = (k_offset + i + k) / kblock; - auto sptr = scales + kpos * NPad + j; - auto tmp = srcptr[(i * ld_src + j * packrow + k) / 2]; - if (zero_points == nullptr) { - dstptr[i * ld_dst + j * packrow + k + 0] = float(get_s8(tmp.x)) * sptr[0]; - dstptr[i * ld_dst + j * packrow + k + 1] = float(get_s8(tmp.y)) * sptr[0]; - } else { - dstptr[i * ld_dst + j * packrow + k + 0] = - (float(get_s8(tmp.x)) - float((zero_points + kpos * NPad)[j + 0])) * sptr[0]; - dstptr[i * ld_dst + j * packrow + k + 1] = - (float(get_s8(tmp.y)) - float((zero_points + kpos * NPad)[j + 0])) * sptr[0]; - } - } - } - } - } - return JblasSuccess; -} - -template -inline JBLAS_CODE decompress_pern_s4_fp_packrow(utils::int4x2* srcptr, float* dstptr, int row, int col, int ld_src, - int ld_dst, float* scales, int8_t* zero_points, int k_offset, - int kblock, int NPad, int packrow) { - (void)NPad; - (void)kblock; - for (int i = 0; i < row; i += packrow) { - if (packrow == 1) { - auto sptr = scales; - for (int j = 0; j < col; j += 2) { - auto tmp = srcptr[i * ld_src / 2 + j / 2]; - if (zero_points == nullptr) { - dstptr[i * ld_dst + j + 0] = float(get_s8(tmp.x)) * sptr[j + 0]; - dstptr[i * ld_dst + j + 1] = float(get_s8(tmp.y)) * sptr[j + 1]; - } else { - dstptr[i * ld_dst + j + 0] = (float(get_s8(tmp.x)) - float((zero_points)[j + 0])) * sptr[j + 0]; - dstptr[i * ld_dst + j + 1] = (float(get_s8(tmp.y)) - float((zero_points)[j + 1])) * sptr[j + 1]; - } - } - } else { - for (int j = 0; j < col; j++) { - auto sptr = scales + j; - for (int k = 0; k < packrow; k += 2) { - auto tmp = srcptr[(i * ld_src + j * packrow + k) / 2]; - if (zero_points == nullptr) { - dstptr[i * ld_dst + j * packrow + k + 0] = float(get_s8(tmp.x)) * sptr[0]; - dstptr[i * ld_dst + j * packrow + k + 1] = float(get_s8(tmp.y)) * sptr[0]; - } else { - dstptr[i * ld_dst + j * packrow + k + 0] = - (float(get_s8(tmp.x)) - float((zero_points)[j + 0])) * sptr[0]; - dstptr[i * ld_dst + j * packrow + k + 1] = - 
(float(get_s8(tmp.y)) - float((zero_points)[j + 0])) * sptr[0]; - } - } - } + dstptr[i * ld_dst + j + 0] = static_cast<_DST_T>(dst0); + dstptr[i * ld_dst + j + 1] = static_cast<_DST_T>(dst1); } } return JblasSuccess; @@ -654,7 +479,7 @@ inline int8_t f4_quantize(float x) { return int8_t(0); } -template +template inline JBLAS_CODE decompress_kblock_f4_fp(utils::f4x2* srcptr, _DST_T* dstptr, int row, int col, int ld_src, int ld_dst, _S_T* scales, int k_offset, int kblock, int NPad) { for (int i = 0; i < row; i++) { @@ -664,60 +489,14 @@ inline JBLAS_CODE decompress_kblock_f4_fp(utils::f4x2* srcptr, _DST_T* dstptr, i auto tmp = srcptr[i * ld_src / 2 + j / 2]; float scale0, scale1, dst0, dst1; int s0_idx, s1_idx; - if constexpr (std::is_same<_DST_T, utils::bf16>::value) { - s0_idx = j / 2; - s1_idx = j / 2; - } else { - s0_idx = j; - s1_idx = j + 1; - } - if constexpr (std::is_same<_S_T, utils::bf16>::value) { - scale0 = sptr[s0_idx].tofloat(); - scale1 = sptr[s1_idx].tofloat(); - } else { - scale0 = sptr[s0_idx]; - scale1 = sptr[s1_idx]; - } + s0_idx = j / _PACK_ROW; + s1_idx = (j + 1) / _PACK_ROW; + scale0 = float(sptr[s0_idx]); + scale1 = float(sptr[s1_idx]); dst0 = f4_dequantize(tmp.x, scale0); dst1 = f4_dequantize(tmp.y, scale1); - if constexpr (std::is_same<_DST_T, utils::bf16>::value) { - utils::bf16 bf16_ret0, bf16_ret1; - bf16_ret0.fromfloat(dst0); - bf16_ret1.fromfloat(dst1); - dstptr[i * ld_dst + j + 0] = bf16_ret0; - dstptr[i * ld_dst + j + 1] = bf16_ret1; - } else { - dstptr[i * ld_dst + j + 0] = dst0; - dstptr[i * ld_dst + j + 1] = dst1; - } - } - } - return JblasSuccess; -} - -template -inline JBLAS_CODE decompress_kblock_f4_fp_packrow(utils::f4x2* srcptr, float* dstptr, int row, int col, int ld_src, - int ld_dst, float* scales, int k_offset, int kblock, int NPad, - int packrow) { - for (int i = 0; i < row; i += packrow) { - if (packrow == 1) { - int kpos = (k_offset + i) / kblock; - auto sptr = scales + kpos * NPad; - for (int j = 0; j < col; j += 2) { 
- auto tmp = srcptr[i * ld_src / 2 + j / 2]; - dstptr[i * ld_dst + j + 0] = f4_dequantize(tmp.x, sptr[j + 0]); - dstptr[i * ld_dst + j + 1] = f4_dequantize(tmp.y, sptr[j + 1]); - } - } else { - for (int j = 0; j < col; j++) { - for (int k = 0; k < packrow; k += 2) { - int kpos = (k_offset + i + k) / kblock; - auto sptr = scales + kpos * NPad + j; - auto tmp = srcptr[(i * ld_src + j * packrow + k) / 2]; - dstptr[i * ld_dst + j * packrow + k + 0] = f4_dequantize(tmp.x, sptr[0]); - dstptr[i * ld_dst + j * packrow + k + 1] = f4_dequantize(tmp.y, sptr[0]); - } - } + dstptr[i * ld_dst + j + 0] = static_cast<_DST_T>(dst0); + dstptr[i * ld_dst + j + 1] = static_cast<_DST_T>(dst1); } } return JblasSuccess; @@ -1090,8 +869,8 @@ static inline JBLAS_CODE row_reduce_sum(const _RT* srcptr, int ldsrc, int row, i return JblasSuccess; } -static inline JBLAS_CODE remove_zeropoint_bias(float* accptr, int ldacc, int row, int col, uint8_t* zps, float* scales, - int lds, const float* reduce) { +static inline JBLAS_CODE remove_act_zeropoint_bias(float* accptr, int ldacc, int row, int col, uint8_t* zps, + float* scales, int lds, const float* reduce) { for (int i = 0; i < row; i++) { auto zpf = float(zps[i * lds]) * scales[i * lds]; for (int j = 0; j < col; j++) { @@ -1100,6 +879,33 @@ static inline JBLAS_CODE remove_zeropoint_bias(float* accptr, int ldacc, int row } return JblasSuccess; } + +static inline JBLAS_CODE remove_wei_zeropoint_bias(float* accptr, int ldacc, int row, int col, int8_t* zps, + float* scales, int lds, const float* reduce) { + for (int i = 0; i < row; i++) { + auto reducef = reduce[i * lds]; + for (int j = 0; j < col; j++) { + accptr[i * ldacc + j] -= float(zps[j]) * scales[j] * reducef; + } + } + return JblasSuccess; +} + +static inline JBLAS_CODE remove_zeropoint_bias(float* accptr, int ldacc, int row, int col, uint8_t* zpa, int8_t* zpb, + float* scalea, float* scaleb, int lds, int k, const float* reducea, + const float* reduceb) { + for (int i = 0; i < row; i++) { 
+ auto reduceaf = reducea[i * lds]; + auto zpaf = float(zpa[i * lds]) * scalea[i * lds]; + for (int j = 0; j < col; j++) { + auto zpbf = float(zpb[j]) * scaleb[j]; + accptr[i * ldacc + j] -= zpbf * reduceaf; + accptr[i * ldacc + j] -= zpaf * reduceb[j]; + accptr[i * ldacc + j] -= zpaf * zpbf * k; + } + } + return JblasSuccess; +} } // namespace ref } // namespace kernel } // namespace jblas diff --git a/intel_extension_for_transformers/llm/library/jblas/jblas/kernel_wrapper.h b/intel_extension_for_transformers/llm/library/jblas/jblas/kernel_wrapper.h index 3db3d0d846b..52557064be9 100644 --- a/intel_extension_for_transformers/llm/library/jblas/jblas/kernel_wrapper.h +++ b/intel_extension_for_transformers/llm/library/jblas/jblas/kernel_wrapper.h @@ -118,15 +118,14 @@ class Memcpy2D { #if CompileAVX2() if constexpr (utils::isa_base::avx2) { ret = kernel::jit::JitMemcpy2DAvx2::forward1<_SRC_T, _DST_T, OP_T>(srcptr, dstptr, row, col, srcstep, dststep, - const_elt_v); + const_elt_v); if (ret == JblasSuccess) { return ret; } } #endif - assert(false);//no ref implementation + assert(false); // no ref implementation return JblasNotSupport; - } }; @@ -343,99 +342,58 @@ class AccumulateDequantizeS32F32 { } }; -template // zero points always be int8_t, not compressed +template // zero points always be int8_t, not compressed class DecompressKBlockS4FP { public: - template + template static inline JBLAS_CODE forward(utils::int4x2* srcptr, _DST_T* dstptr, int row, int col, int ld_src, int ld_dst, - _T* scales, int8_t* zero_points, int k_offset, int kblock, int NPad) { + _SCA_T* scales, int8_t* zero_points, int k_offset, int kblock, int NPad) { JBLAS_CODE ret = JblasNotSupport; - #if CompileAVX512F() if constexpr (utils::isa_base::avx512f) { - ret = avx512f::decompress_kblock_s4_fp<_T, _DST_T, S4_T>(srcptr, dstptr, row, col, ld_src, ld_dst, scales, - zero_points, k_offset, kblock, NPad); - return ret; + ret = avx512f::decompress_kblock_s4_fp( + srcptr, dstptr, row, col, ld_src, 
ld_dst, scales, zero_points, k_offset, kblock, NPad); + if (ret == JblasSuccess) return ret; } #endif #if CompileAVX2() - if constexpr (utils::isa_base::avx2 && std::is_same_v<_DST_T, float>) { - ret = avx2::decompress_kblock_bit4_fp32(srcptr, dstptr, row, col, ld_src, ld_dst, scales, zero_points, k_offset, - kblock, NPad, &avx2::dequant_s8_N_avx2<48>, - &avx2::convert_s4_s8_16_sse); - return ret; + // AVX2 device only focus on fp32 data and layout + if constexpr (utils::isa_base::avx2 && std::is_same_v<_DST_T, float> && _PACK_ROW == 1) { + ret = avx2::decompress_kblock_bit4_packrow1(srcptr, dstptr, row, col, ld_src, ld_dst, scales, zero_points, + k_offset, kblock, NPad, &avx2::dequant_s8_N_avx2<48>, + &avx2::convert_s4_s8_16_sse); + if (ret == JblasSuccess) return ret; } #endif - ret = ref::decompress_kblock_s4_fp(srcptr, dstptr, row, col, ld_src, ld_dst, scales, zero_points, - k_offset, kblock, NPad); + ret = ref::decompress_kblock_s4_fp(srcptr, dstptr, row, col, ld_src, ld_dst, + scales, zero_points, k_offset, kblock, NPad); return ret; } }; -template // zero points always be int8_t, not compressed -class DecompressPerNS4FP { - public: - template - static inline JBLAS_CODE forward(utils::int4x2* srcptr, _DST_T* dstptr, int row, int col, int ld_src, int ld_dst, - _T* scales, int8_t* zero_points, int k_offset, int kblock, int NPad) { - return ref::decompress_pern_s4_fp(srcptr, dstptr, row, col, ld_src, ld_dst, scales, zero_points, - k_offset, kblock, NPad); - } -}; - -template -class DecompressKBlockS4FPPackRow { - public: - template - static inline JBLAS_CODE forward(utils::int4x2* srcptr, _DST_T* dstptr, int row, int col, int ld_src, int ld_dst, - _T* scales, int8_t* zero_points, int k_offset, int kblock, int NPad, int packrow) { - return ref::decompress_kblock_s4_fp_packrow(srcptr, dstptr, row, col, ld_src, ld_dst, scales, zero_points, - k_offset, kblock, NPad, packrow); - } -}; - -template -class DecompressPerNS4FPPackRow { - public: - template - static inline 
JBLAS_CODE forward(utils::int4x2* srcptr, _DST_T* dstptr, int row, int col, int ld_src, int ld_dst, - _T* scales, int8_t* zero_points, int k_offset, int kblock, int NPad, int packrow) { - return ref::decompress_pern_s4_fp_packrow(srcptr, dstptr, row, col, ld_src, ld_dst, scales, zero_points, - k_offset, kblock, NPad, packrow); - } -}; - -template -class DecompressKBlockF4FPPackRow { - public: - template - static inline JBLAS_CODE forward(utils::f4x2* srcptr, _DST_T* dstptr, int row, int col, int ld_src, int ld_dst, - _T* scales, int k_offset, int kblock, int NPad, int packrow) { - return ref::decompress_kblock_f4_fp_packrow(srcptr, dstptr, row, col, ld_src, ld_dst, scales, k_offset, - kblock, NPad, packrow); - } -}; - -template +template class DecompressKBlockF4Fp { public: - template + template static inline JBLAS_CODE forward(utils::f4x2* srcptr, _DST_T* dstptr, int row, int col, int ld_src, int ld_dst, - _T* scales, int k_offset, int kblock, int NPad) { + SCA_T* scales, int k_offset, int kblock, int NPad) { + JBLAS_CODE ret = JblasNotSupport; #if CompileAVX512F() if constexpr (utils::isa_base::avx512f) { - return avx512f::decompress_kblock_f4_fp<_T, _DST_T, F4_T>(srcptr, dstptr, row, col, ld_src, ld_dst, scales, - k_offset, kblock, NPad); + ret = avx512f::decompress_kblock_f4_fp(srcptr, dstptr, row, col, ld_src, ld_dst, + scales, k_offset, kblock, NPad); + if (ret == JblasSuccess) return ret; } #endif #if CompileAVX2() if constexpr (utils::isa_base::avx2) { - return avx2::decompress_kblock_f4_fp<_T, _DST_T, F4_T>(srcptr, dstptr, row, col, ld_src, ld_dst, scales, k_offset, - kblock, NPad); + ret = avx2::decompress_kblock_f4_fp(srcptr, dstptr, row, col, ld_src, ld_dst, + scales, k_offset, kblock, NPad); + if (ret == JblasSuccess) return ret; } #endif - return ref::decompress_kblock_f4_fp(srcptr, dstptr, row, col, ld_src, ld_dst, scales, k_offset, - kblock, NPad); + return ref::decompress_kblock_f4_fp(srcptr, dstptr, row, col, ld_src, ld_dst, + scales, k_offset, 
kblock, NPad); } }; @@ -459,36 +417,26 @@ class DecompressKBlockS4S8 { return ref::decompress_s4_s8(srcptr, dstptr, row, col, ld_src, ld_dst); } }; - +template class DecompressKBlockS8F32 { public: - template - static inline JBLAS_CODE forward(int8_t* srcptr, float* dstptr, int row, int col, int ld_src, int ld_dst, _T* scales, - int8_t* zero_points, int k_offset, int kblock, int NPad) { + template + static inline JBLAS_CODE forward(int8_t* srcptr, float* dstptr, int row, int col, int ld_src, int ld_dst, + SCA_T* scales, int8_t* zero_points, int k_offset, int kblock, int NPad) { #if CompileAVX512F() - if (utils::isa_base::avx512f) { + if (utils::isa_base::avx512f && PACK_ROW == 1) { return jit::DequanKBlockS8F32::forward_avx512f(srcptr, dstptr, row, col, ld_src, ld_dst, scales, zero_points, k_offset, kblock, NPad); } #endif #if CompileAVX2() - if (utils::isa_base::avx2) { + if (utils::isa_base::avx2 && PACK_ROW == 1) { return avx2::dequant_kblock_s8_f32(srcptr, dstptr, row, col, ld_src, ld_dst, scales, zero_points, k_offset, kblock, NPad); } #endif - return ref::decompress_kblock_s8_f32(srcptr, dstptr, row, col, ld_src, ld_dst, scales, zero_points, k_offset, - kblock, NPad); - } -}; - -class DecompressKBlockS8FP32PackRow { - public: - template - static inline JBLAS_CODE forward(int8_t* srcptr, float* dstptr, int row, int col, int ld_src, int ld_dst, _T* scales, - int8_t* zero_points, int k_offset, int kblock, int NPad, int packrow) { - return ref::decompress_kblock_s8_f32_packrow(srcptr, dstptr, row, col, ld_src, ld_dst, scales, zero_points, - k_offset, kblock, NPad, packrow); + return ref::decompress_kblock_s8_f32(srcptr, dstptr, row, col, ld_src, ld_dst, scales, + zero_points, k_offset, kblock, NPad); } }; @@ -572,13 +520,30 @@ class RowReduceSum { class RemoveZeroPointBias { public: + template + static inline JBLAS_CODE forward(float* accptr, int ldacc, int row, int col, int8_t* zps, float* scales, int lds, + const float* reduce) { + if constexpr 
(utils::isa_base::avx512f) { + return avx512f::remove_wei_zeropoint_bias(accptr, ldacc, row, col, zps, scales, lds, reduce); + } + return ref::remove_wei_zeropoint_bias(accptr, ldacc, row, col, zps, scales, lds, reduce); + } template static inline JBLAS_CODE forward(float* accptr, int ldacc, int row, int col, uint8_t* zps, float* scales, int lds, const float* reduce) { if constexpr (utils::isa_base::avx512f) { - return avx512f::remove_zeropoint_bias(accptr, ldacc, row, col, zps, scales, lds, reduce); + return avx512f::remove_act_zeropoint_bias(accptr, ldacc, row, col, zps, scales, lds, reduce); + } + return ref::remove_act_zeropoint_bias(accptr, ldacc, row, col, zps, scales, lds, reduce); + } + template + static inline JBLAS_CODE forward(float* accptr, int ldacc, int row, int col, uint8_t* zpa, int8_t* zpb, float* scalea, + float* scaleb, int lds, int k, const float* reducea, const float* reduceb) { + if constexpr (utils::isa_base::avx512f) { + return avx512f::remove_zeropoint_bias(accptr, ldacc, row, col, zpa, zpb, scalea, scaleb, lds, k, reducea, + reduceb); } - return ref::remove_zeropoint_bias(accptr, ldacc, row, col, zps, scales, lds, reduce); + return ref::remove_zeropoint_bias(accptr, ldacc, row, col, zpa, zpb, scalea, scaleb, lds, k, reducea, reduceb); } }; diff --git a/intel_extension_for_transformers/llm/operator/cscr/dispatcher/src/jblas_weightonly_dispatcher.cpp b/intel_extension_for_transformers/llm/operator/cscr/dispatcher/src/jblas_weightonly_dispatcher.cpp index bdb5688f0e2..4389efe6cb3 100644 --- a/intel_extension_for_transformers/llm/operator/cscr/dispatcher/src/jblas_weightonly_dispatcher.cpp +++ b/intel_extension_for_transformers/llm/operator/cscr/dispatcher/src/jblas_weightonly_dispatcher.cpp @@ -19,10 +19,10 @@ #include #include #include +#include #include #include #include -#include #include "jblas/jit_blas.h" #include "jblas/jit_blas_epilogue.h" #include "jblas/jit_blas_gemm.h" @@ -91,17 +91,18 @@ void qbits_quantize(qbits_config_param* p, 
qbits_runtime_ctx* ctx) { static PrologueB compress_kernel; set_nk(ctx, ctx->weight); + if (initer.verbose) timer.start(); auto do_quant = [&](typename PrologueB::StorageWeight* ptr) { - std::vector buffer(ptr->mSize); - ptr->assign(buffer.data()); + int8_t* buffer = jblas::utils::amalloc(ptr->mSize); + ptr->assign(buffer); if (ctx->transpose) compress_kernel.packTransposeWeight(ctx->n, ctx->k, ctx->weight->data_ptr(), ctx->k, ptr); else compress_kernel.packWeight(ctx->n, ctx->k, ctx->weight->data_ptr(), ctx->n, ptr); *(ctx->output) = torch::zeros(ptr->mSize, torch::kInt8); ptr->serialize(ctx->output->data_ptr()); + jblas::utils::afree(buffer); }; - if constexpr (!perchannel_Gemmcore) { auto storage = compress_kernel.createStorage(ctx->n, ctx->k, ctx->blocksize); do_quant(&storage); @@ -109,6 +110,13 @@ void qbits_quantize(qbits_config_param* p, qbits_runtime_ctx* ctx) { auto storage = compress_kernel.createStorage(ctx->n, ctx->k, false); do_quant(&storage); } + if (initer.verbose) { + timer.stop(); + auto cost_time = timer.get_elapsed_time(); + std::cout << "QBits quantize verbose\nn:" << ctx->n << " k:" << ctx->k << " weight_type:" << p->weight_type + << " blocksize:" << ctx->blocksize << " src_type:" << dispatcher_utils::get_torch_dt_name(ctx->weight) + << " execute time:" << cost_time << "ms" << std::endl; + } } template @@ -116,12 +124,20 @@ void qbits_dequantize(qbits_config_param* p, qbits_runtime_ctx* ctx) { using PrologueB = typename KERNEL::WeightType; static PrologueB decompress_kernel; set_nk(ctx, ctx->output); + if (initer.verbose) timer.start(); if (ctx->transpose) decompress_kernel.unpackTransposeWeight(int(ctx->n), int(ctx->k), ctx->deseries_wei, ctx->output->data_ptr(), int(ctx->k)); else decompress_kernel.unpackWeight(int(ctx->n), int(ctx->k), ctx->deseries_wei, ctx->output->data_ptr(), int(ctx->n)); + if (initer.verbose) { + timer.stop(); + auto cost_time = timer.get_elapsed_time(); + std::cout << "QBits dequantize verbose\nn:" << ctx->n << " k:" 
<< ctx->k << " weight_type:" << p->weight_type + << " blocksize:" << ctx->blocksize << " dst_type:" << dispatcher_utils::get_torch_dt_name(ctx->output) + << " execute time:" << cost_time << "ms" << std::endl; + } } template @@ -136,7 +152,7 @@ void do_compute(qbits_config_param* p, qbits_runtime_ctx* ctx, const ParamA para if (initer.verbose) { timer.stop(); auto cost_time = timer.get_elapsed_time(); - std::cout << "QBits verbose\nm:" << ctx->m << " n:" << ctx->n << " k:" << ctx->k + std::cout << "QBits linear verbose\nm:" << ctx->m << " n:" << ctx->n << " k:" << ctx->k << " weight_type:" << p->weight_type << " compute_type:" << p->compute_type << " blocksize:" << ctx->blocksize << " src_type:" << dispatcher_utils::get_torch_dt_name(ctx->activation) << " dst_type:" << dispatcher_utils::get_torch_dt_name(ctx->output) << " execute time:" << cost_time @@ -201,7 +217,7 @@ void parse_paramA(qbits_config_param* p, qbits_runtime_ctx* ctx) { "Qbits: workspace size should large than " + std::to_string(need_size) + " bytes"); return workspace; } else { - tmpbuf = malloc(need_size); + tmpbuf = jblas::utils::amalloc(need_size); return tmpbuf; } }; @@ -218,7 +234,7 @@ void parse_paramA(qbits_config_param* p, qbits_runtime_ctx* ctx) { ParamA param_a = {reinterpret_cast(ctx->activation->data_ptr()), ctx->lda, &quantA}; parse_paramC(p, ctx, param_a); } - if (tmpbuf != NULL) free(tmpbuf); + if (tmpbuf != NULL) jblas::utils::afree(tmpbuf); } } diff --git a/intel_extension_for_transformers/llm/runtime/graph/core/layers/inner_product.cpp b/intel_extension_for_transformers/llm/runtime/graph/core/layers/inner_product.cpp index 9e62eec3288..257d54e119e 100644 --- a/intel_extension_for_transformers/llm/runtime/graph/core/layers/inner_product.cpp +++ b/intel_extension_for_transformers/llm/runtime/graph/core/layers/inner_product.cpp @@ -182,8 +182,15 @@ static JBLAS_CODE jblas_s8fp32perN_f32f32_forward(float* activation, SS8Fp32PerN static GemmKernel kernel; auto quanA = 
kernel.getActivationPtr()->createStorage(_m, _k); quanA.assign((int8_t*)workspace); - ret = kernel.compute({_m, _n, _k, activation, lda, &quanA, weiptr, output, ldo, quanA.mZPtr, - quanA.mSPtr, quanA.mCStep, weiptr->mRPtr, weiptr->mSPtr}); + ret = kernel.compute( + {_m, + _n, + _k, + activation, + lda, + &quanA, + weiptr, + {output, ldo, quanA.mCStep, quanA.mSPtr, weiptr->mSPtr, quanA.mZPtr, weiptr->mRPtr}}); } } return ret; @@ -207,8 +214,15 @@ static JBLAS_CODE jblas_s4fp32perN_f32f32_forward(float* activation, SS4Fp32PerN static GemmKernel kernel; auto quanA = kernel.getActivationPtr()->createStorage(_m, _k); quanA.assign((int8_t*)workspace); - ret = kernel.compute({_m, _n, _k, activation, lda, &quanA, weiptr, output, ldo, quanA.mZPtr, - quanA.mSPtr, quanA.mCStep, weiptr->mRPtr, weiptr->mSPtr}); + ret = kernel.compute( + {_m, + _n, + _k, + activation, + lda, + &quanA, + weiptr, + {output, ldo, quanA.mCStep, quanA.mSPtr, weiptr->mSPtr, quanA.mZPtr, weiptr->mRPtr}}); } } return ret; @@ -348,9 +362,17 @@ JBLAS_CODE jblas_fusion_add_s8fp32pern_f32f32_forward(float* activation, SS8Fp32 static GemmKernel kernel; auto quanA = kernel.getActivationPtr()->createStorage(_m, _k); quanA.assign((int8_t*)workspace); - ret = kernel.compute({_m, _n, _k, activation, lda, &quanA, weiptr, output, ldo, quanA.mZPtr, - quanA.mSPtr, quanA.mCStep, weiptr->mRPtr, weiptr->mSPtr, bias, - broadcast_bias ? 0 : ldo}); + ret = kernel.compute( + {_m, + _n, + _k, + activation, + lda, + &quanA, + weiptr, + {{output, ldo, quanA.mCStep, quanA.mSPtr, weiptr->mSPtr, quanA.mZPtr, weiptr->mRPtr}, + bias, + broadcast_bias ? 
0 : ldo}}); } } return ret; diff --git a/intel_extension_for_transformers/llm/runtime/graph/core/layers/ip_fusion_ffn.cpp b/intel_extension_for_transformers/llm/runtime/graph/core/layers/ip_fusion_ffn.cpp index 5c084825ae3..aab1f364618 100644 --- a/intel_extension_for_transformers/llm/runtime/graph/core/layers/ip_fusion_ffn.cpp +++ b/intel_extension_for_transformers/llm/runtime/graph/core/layers/ip_fusion_ffn.cpp @@ -204,11 +204,22 @@ JBLAS_CODE jblas_fusion_FFN_SiLu_s8fp32pern_f32f32_forward(float* activation, SS auto offset = workspace == NULL ? 0 : quanA1.mSize; auto quanA2 = finter.getActivationPtr()->createStorage(seq, fmid); quanA2.assign((int8_t*)workspace + offset); - ret = finter.compute({seq, fin, fmid, fout, activation, lda, &quanA1, - tmp1, ldtmp1, &quanA2, w1ptr, w2ptr, w3ptr, tmp1, - ldtmp1, quanA1.mZPtr, quanA1.mSPtr, quanA1.mCStep, w1ptr->mRPtr, w1ptr->mSPtr, output, - ldo, quanA2.mZPtr, quanA2.mSPtr, quanA2.mCStep, w2ptr->mRPtr, w2ptr->mSPtr, tmp2, - ldtmp2, quanA1.mZPtr, quanA1.mSPtr, quanA1.mCStep, w3ptr->mRPtr, w3ptr->mSPtr}); + ret = finter.compute({seq, + fin, + fmid, + fout, + activation, + lda, + &quanA1, + tmp1, + ldtmp1, + &quanA2, + w1ptr, + w2ptr, + w3ptr, + {tmp1, ldtmp1, quanA1.mCStep, quanA1.mSPtr, w1ptr->mSPtr, quanA1.mZPtr, w1ptr->mRPtr}, + {output, ldo, quanA2.mCStep, quanA2.mSPtr, w2ptr->mSPtr, quanA2.mZPtr, w2ptr->mRPtr}, + {tmp2, ldtmp2, quanA1.mCStep, quanA1.mSPtr, w3ptr->mSPtr, quanA1.mZPtr, w3ptr->mRPtr}}); } } return ret; @@ -255,11 +266,22 @@ JBLAS_CODE jblas_fusion_FFN_SiLu_s4clipfp32pern_f32f32_forward(float* activation auto offset = workspace == NULL ? 
0 : quanA1.mSize; auto quanA2 = finter.getActivationPtr()->createStorage(seq, fmid); quanA2.assign((int8_t*)workspace + offset); - ret = finter.compute({seq, fin, fmid, fout, activation, lda, &quanA1, - tmp1, ldtmp1, &quanA2, w1ptr, w2ptr, w3ptr, tmp1, - ldtmp1, quanA1.mZPtr, quanA1.mSPtr, quanA1.mCStep, w1ptr->mRPtr, w1ptr->mSPtr, output, - ldo, quanA2.mZPtr, quanA2.mSPtr, quanA2.mCStep, w2ptr->mRPtr, w2ptr->mSPtr, tmp2, - ldtmp2, quanA1.mZPtr, quanA1.mSPtr, quanA1.mCStep, w3ptr->mRPtr, w3ptr->mSPtr}); + ret = finter.compute({seq, + fin, + fmid, + fout, + activation, + lda, + &quanA1, + tmp1, + ldtmp1, + &quanA2, + w1ptr, + w2ptr, + w3ptr, + {tmp1, ldtmp1, quanA1.mCStep, quanA1.mSPtr, w1ptr->mSPtr, quanA1.mZPtr, w1ptr->mRPtr}, + {output, ldo, quanA2.mCStep, quanA2.mSPtr, w2ptr->mSPtr, quanA2.mZPtr, w2ptr->mRPtr}, + {tmp2, ldtmp2, quanA1.mCStep, quanA1.mSPtr, w3ptr->mSPtr, quanA1.mZPtr, w3ptr->mRPtr}}); } } return ret; @@ -682,16 +704,24 @@ JBLAS_CODE jblas_fusion_FFN_Add_GeLu_s8fp32pern_f32f32_forward(float* activation auto offset = workspace == NULL ? 0 : quanA1.mSize; auto quanA2 = finter.getActivationPtr()->createStorage(seq, fmid); quanA2.assign((int8_t*)workspace + offset); - ret = finter.compute({seq, fin, fmid, - fout, activation, lda, - &quanA1, tmp1, ldtmp1, - &quanA2, w1tmp, w2tmp, - tmp1, ldtmp1, quanA1.mZPtr, - quanA1.mSPtr, quanA1.mCStep, w1tmp->mRPtr, - w1tmp->mSPtr, b1ptr, broadcast_bias ? 0 : ldtmp1, - output, ldo, quanA2.mZPtr, - quanA2.mSPtr, quanA2.mCStep, w2tmp->mRPtr, - w2tmp->mSPtr, b2ptr, broadcast_bias ? 0 : ldo}); + ret = finter.compute({seq, + fin, + fmid, + fout, + activation, + lda, + &quanA1, + tmp1, + ldtmp1, + &quanA2, + w1tmp, + w2tmp, + {{tmp1, ldtmp1, quanA1.mCStep, quanA1.mSPtr, w1tmp->mSPtr, quanA1.mZPtr, w1tmp->mRPtr}, + b1ptr, + broadcast_bias ? 0 : ldtmp1}, + {{output, ldo, quanA2.mCStep, quanA2.mSPtr, w2tmp->mSPtr, quanA2.mZPtr, w2tmp->mRPtr}, + b2ptr, + broadcast_bias ? 
0 : ldo}}); } } return ret; @@ -765,16 +795,24 @@ JBLAS_CODE jblas_fusion_FFN_Add_GeLu_s4clipfp32pern_f32f32_forward(float* activa auto offset = workspace == NULL ? 0 : quanA1.mSize; auto quanA2 = finter.getActivationPtr()->createStorage(seq, fmid); quanA2.assign((int8_t*)workspace + offset); - ret = finter.compute({seq, fin, fmid, - fout, activation, lda, - &quanA1, tmp1, ldtmp1, - &quanA2, w1tmp, w2tmp, - tmp1, ldtmp1, quanA1.mZPtr, - quanA1.mSPtr, quanA1.mCStep, w1tmp->mRPtr, - w1tmp->mSPtr, b1ptr, broadcast_bias ? 0 : ldtmp1, - output, ldo, quanA2.mZPtr, - quanA2.mSPtr, quanA2.mCStep, w2tmp->mRPtr, - w2tmp->mSPtr, b2ptr, broadcast_bias ? 0 : ldo}); + ret = finter.compute({seq, + fin, + fmid, + fout, + activation, + lda, + &quanA1, + tmp1, + ldtmp1, + &quanA2, + w1tmp, + w2tmp, + {{tmp1, ldtmp1, quanA1.mCStep, quanA1.mSPtr, w1tmp->mSPtr, quanA1.mZPtr, w1tmp->mRPtr}, + b1ptr, + broadcast_bias ? 0 : ldtmp1}, + {{output, ldo, quanA2.mCStep, quanA2.mSPtr, w2tmp->mSPtr, quanA2.mZPtr, w2tmp->mRPtr}, + b2ptr, + broadcast_bias ? 
0 : ldo}}); } } return ret; diff --git a/intel_extension_for_transformers/llm/runtime/graph/core/layers/ip_fusion_qkv.cpp b/intel_extension_for_transformers/llm/runtime/graph/core/layers/ip_fusion_qkv.cpp index 666ba44afab..fe7cd60453c 100644 --- a/intel_extension_for_transformers/llm/runtime/graph/core/layers/ip_fusion_qkv.cpp +++ b/intel_extension_for_transformers/llm/runtime/graph/core/layers/ip_fusion_qkv.cpp @@ -255,9 +255,9 @@ JBLAS_CODE jblas_QKVs8fp32pern_f32f32_forward(float* activation, SS8Fp32PerN* wq wvptr, }; GemmKernel::CParam oparams[3]{ - {output, ldo, quanA.mZPtr, quanA.mSPtr, quanA.mCStep, wqptr->mRPtr, wqptr->mSPtr}, - {output + _m * _n, ldo, quanA.mZPtr, quanA.mSPtr, quanA.mCStep, wkptr->mRPtr, wkptr->mSPtr}, - {output + 2 * _m * _n, ldo, quanA.mZPtr, quanA.mSPtr, quanA.mCStep, wvptr->mRPtr, wvptr->mSPtr}, + {output, ldo, quanA.mCStep, quanA.mSPtr, wqptr->mSPtr, quanA.mZPtr, wqptr->mRPtr}, + {output + _m * _n, ldo, quanA.mCStep, quanA.mSPtr, wkptr->mSPtr, quanA.mZPtr, wkptr->mRPtr}, + {output + 2 * _m * _n, ldo, quanA.mCStep, quanA.mSPtr, wvptr->mSPtr, quanA.mZPtr, wvptr->mRPtr}, }; ret = kernel.compute({_m, _n, _k, 3, activation, lda, &quanA, wparams, oparams, NULL}); } @@ -298,9 +298,9 @@ JBLAS_CODE jblas_QKVs4clipfp32pern_f32f32_forward(float* activation, SS4Fp32PerN wvptr, }; GemmKernel::CParam oparams[3]{ - {output, ldo, quanA.mZPtr, quanA.mSPtr, quanA.mCStep, wqptr->mRPtr, wqptr->mSPtr}, - {output + _m * _n, ldo, quanA.mZPtr, quanA.mSPtr, quanA.mCStep, wkptr->mRPtr, wkptr->mSPtr}, - {output + 2 * _m * _n, ldo, quanA.mZPtr, quanA.mSPtr, quanA.mCStep, wvptr->mRPtr, wvptr->mSPtr}, + {output, ldo, quanA.mCStep, quanA.mSPtr, wqptr->mSPtr, quanA.mZPtr, wqptr->mRPtr}, + {output + _m * _n, ldo, quanA.mCStep, quanA.mSPtr, wkptr->mSPtr, quanA.mZPtr, wkptr->mRPtr}, + {output + 2 * _m * _n, ldo, quanA.mCStep, quanA.mSPtr, wvptr->mSPtr, quanA.mZPtr, wvptr->mRPtr}, }; ret = kernel.compute({_m, _n, _k, 3, activation, lda, &quanA, wparams, oparams, 
NULL}); } diff --git a/intel_extension_for_transformers/llm/runtime/graph/core/layers/mha_dense.cpp b/intel_extension_for_transformers/llm/runtime/graph/core/layers/mha_dense.cpp index 803630a74a6..2a8b557bb2b 100644 --- a/intel_extension_for_transformers/llm/runtime/graph/core/layers/mha_dense.cpp +++ b/intel_extension_for_transformers/llm/runtime/graph/core/layers/mha_dense.cpp @@ -709,12 +709,12 @@ class MHAInterface { // TODO(Yi): init packed weight with p.tmp PackedWeightBatch K_pack(jblas::gemm::GemmCoreType::AMX_BF16_16x64); // packed K K_pack.resize(padto(p.sl_kv, GemmQK::NTILE), padto(p.head_size, GemmQK::KTILE), num_heads); - jblas::utils::avector bufferK(K_pack.mSize); - K_pack.assign(bufferK.data()); + auto bufferK = jblas::utils::amalloc(K_pack.mSize); + K_pack.assign(bufferK); PackedWeightBatch V_pack(jblas::gemm::GemmCoreType::AMX_BF16_16x64); // packed V V_pack.resize(padto(p.head_size, GemmPV::NTILE), padto(p.sl_kv, GemmPV::KTILE), num_heads); - jblas::utils::avector bufferV(V_pack.mSize); - V_pack.assign(bufferV.data()); + auto bufferV = jblas::utils::amalloc(V_pack.mSize); + V_pack.assign(bufferV); const auto K_pack_batch_off = K_pack.mKPad * K_pack.mNPad; const auto V_pack_batch_off = V_pack.mKPad * V_pack.mNPad; @@ -853,6 +853,8 @@ class MHAInterface { } } } + jblas::utils::afree(bufferK); + jblas::utils::afree(bufferV); return JblasSuccess; }