Skip to content
This repository has been archived by the owner on Oct 25, 2024. It is now read-only.

Commit

Permalink
[CPP Graph] Opt qbits dequant (#465)
Browse files Browse the repository at this point in the history
  • Loading branch information
zhewang1-intc authored Oct 19, 2023
1 parent 1ab6ce3 commit f04d0fd
Show file tree
Hide file tree
Showing 15 changed files with 542 additions and 703 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -137,13 +137,19 @@ template <JBLAS_ISA ISA_T>
class ZpDequantInt32ToFp32 {
public:
struct Param {
// necessary
float* C;
int ldc;
uint8_t* zpA;
float* scalesA;
int ldsa;
float* reduceB;
float* scalesA;
float* scalesB;
// optional if A asym
uint8_t* zpA = nullptr;
float* reduceB = nullptr;
// optional if B asym
int8_t* zpB = nullptr;
float* reduceA = nullptr;
int K = 1;
};
JBLAS_CODE forward(const int32_t* cacheptr, const int cachestep, const int M_offset, const int N_offset, const int M,
const int N, const Param& _param) {
Expand All @@ -155,9 +161,22 @@ class ZpDequantInt32ToFp32 {
if (ret != JblasSuccess) {
return ret;
}
ret = kernel::wrapper::RemoveZeroPointBias::template forward<ISA_T>(
cptr, _param.ldc, M, N, _param.zpA + M_offset * _param.ldsa, _param.scalesA + M_offset * _param.ldsa,
_param.ldsa, _param.reduceB + N_offset);
if (_param.zpA == nullptr && _param.zpB == nullptr) {
return ret;
} else if (_param.zpA != nullptr && _param.zpB == nullptr) {
ret = kernel::wrapper::RemoveZeroPointBias::template forward<ISA_T>(
cptr, _param.ldc, M, N, _param.zpA + M_offset * _param.ldsa, _param.scalesA + M_offset * _param.ldsa,
_param.ldsa, _param.reduceB + N_offset);
} else if (_param.zpA == nullptr && _param.zpB != nullptr) {
ret = kernel::wrapper::RemoveZeroPointBias::template forward<ISA_T>(cptr, _param.ldc, M, N, _param.zpB + N_offset,
_param.scalesB + N_offset, _param.ldsa,
_param.reduceA + M_offset * _param.ldsa);
} else {
ret = kernel::wrapper::RemoveZeroPointBias::template forward<ISA_T>(
cptr, _param.ldc, M, N, _param.zpA + M_offset * _param.ldsa, _param.zpB + N_offset,
_param.scalesA + M_offset * _param.ldsa, _param.scalesB + N_offset, _param.ldsa, _param.K,
_param.reduceA + M_offset * _param.ldsa, _param.reduceB + N_offset);
}
return ret;
}
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -819,9 +819,10 @@ class WeightPack {
}

void packWeightTranspose(const int N, const int K, const Param& _param) {
utils::aligned_vector<WType> B_NT(N * K);
transposeWeight<WType, ISA_T>(N, K, _param.B, _param.ldb, B_NT.data(), N);
return packWeight(N, K, {B_NT.data(), N, _param.packedW});
auto B_NT = utils::amalloc<WType>((size_t)N * K);
transposeWeight<WType, ISA_T>(N, K, _param.B, _param.ldb, B_NT, N);
packWeight(N, K, {B_NT, N, _param.packedW});
utils::afree(B_NT);
}

// from KxN int8 symmetric weight to packed N//NtilexKPadxNTile int4 weight
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,21 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#ifdef _OPENMP
#include <omp.h>
#endif

#include <algorithm>
#include <chrono>
#include <cmath>
#include <cstring>
#include <functional>
#include <vector>
#ifdef _WIN32
#include <cstdlib>
#else
#include <stdlib.h>
#endif

#include "jit_blas.h"
#include "xbyak/xbyak_util.h"
Expand Down Expand Up @@ -279,7 +286,7 @@ static inline _DSTT cast(_SRCT _src) {

template <>
int8_t cast(float _src) {
_src = _src >= 0.f ? _src + 0.5f : _src - 0.5f;
_src = roundf(_src);
_src = std::min(_src, 127.f);
_src = std::max(_src, -128.f);
return static_cast<int8_t>(_src);
Expand Down Expand Up @@ -320,11 +327,36 @@ _T deserialize(int8_t*& buf) {
// Pads `a` to a whole number of `b`-sized units: the unit count comes from
// updiv(a, b) and is scaled back by `b`.
// NOTE(review): updiv is defined outside this view; presumably a ceiling
// division - confirm against its definition.
static inline int padto(int a, int b) {
  const auto units = updiv(a, b);
  return units * b;
}

// size_t flavour of padto for byte counts; `b` stays an int to match updiv's
// call shape used here.
static inline size_t padto(size_t a, int b) {
  const auto units = updiv(a, b);
  return units * b;
}

// Allocates uninitialized storage for `_size` objects of type `_T`, aligned to
// `_alignment` bytes (default 64).
//
// The byte count is rounded up to a multiple of `_alignment` before the
// allocator is called, because aligned_alloc expects the size to be a multiple
// of the alignment.
//
// Returns nullptr when:
//   - `_size` is 0,
//   - `_alignment` is 0,
//   - the padded byte count would not fit in size_t,
//   - the underlying allocator fails.
//
// The returned pointer comes from _aligned_malloc on Windows and aligned_alloc
// elsewhere; release it with afree(), which selects the matching deallocator.
template <typename _T>
static inline _T* amalloc(size_t _size, size_t _alignment = 64) {
  if (_size == 0 || _alignment == 0) {
    return nullptr;
  }
  // Largest byte count that can still be rounded up to `_alignment` without
  // wrapping size_t. Without this check a huge `_size` silently yields a
  // too-small buffer.
  const size_t max_bytes = static_cast<size_t>(-1) - (_alignment - 1);
  if (_size > max_bytes / sizeof(_T)) {
    return nullptr;
  }
  const size_t bytes = _size * sizeof(_T);
  // Round up in size_t arithmetic so the alignment is never narrowed to int.
  const size_t psize = (bytes + _alignment - 1) / _alignment * _alignment;
#ifdef _WIN32
  return static_cast<_T*>(_aligned_malloc(psize, _alignment));
#else
  return static_cast<_T*>(aligned_alloc(_alignment, psize));
#endif
}

// Releases a buffer through the platform deallocator that pairs with the
// allocator amalloc uses: _aligned_free on Windows, free elsewhere.
// A null pointer is accepted and ignored.
static inline void afree(void* ptr) {
  if (ptr != nullptr) {
#ifdef _WIN32
    _aligned_free(ptr);
#else
    free(ptr);
#endif
  }
}

template <typename _T, int _Alignment = 64>
class aligned_vector {
public:
aligned_vector() : mRawsize(0), mPtr(nullptr), mAlignedsize(0) {}
aligned_vector(size_t _size, _T _val = _T(0)) {
aligned_vector(size_t _size) { resize(_size); }
aligned_vector(size_t _size, _T _val) {
resize(_size);
std::fill_n(mVec.begin(), mVec.size(), _val);
}
Expand Down Expand Up @@ -502,12 +534,15 @@ class CpuDevice {
ADD_FLAG(AVX512_BF16);
ADD_FLAG(AVX512_FP16);
numcores = _cpu.getNumCores(Xbyak::util::IntelCpuTopologyLevel::CoreLevel);
#ifdef _OPENMP
ompthreads = omp_get_max_threads();
numthreads = std::min(numcores, ompthreads);
#ifdef FORCE_NUM_THREADS
numthreads = FORCE_NUM_THREADS;
#else
ompthreads = numcores;
#endif
numthreads = std::min(numcores, ompthreads);
#ifdef _OPENMP
omp_set_num_threads(numthreads);
#endif
}

static CpuDevice* getInstance() {
Expand Down
Loading

0 comments on commit f04d0fd

Please sign in to comment.