Our kernels are based on the x64 template library BESTLA.
Limited by the graph framework, we only add kernels that accept float tensors as both input and output.
input dtype | output dtype | compute type | compute ISA |
---|---|---|---|
float32 | float32 | float32 | AVX2 |
float32 | float32 | float32 | AVX512F |
float32<sup>1</sup> | float32<sup>2</sup> | int8 | AVX512_VNNI |
float32<sup>1</sup> | float32<sup>2</sup> | int8 | AVX_VNNI |
float32<sup>1</sup> | float32<sup>2</sup> | int8 | AMX_INT8 |
float32/bf16 | float32/bf16 | bf16 | AMX_BF16 |
float32/fp16 | float32/fp16 | fp16 | AVX512_FP16 |
<sup>1</sup>: per-batch and per-K group-wise dynamic quantization of the input tensor, where the per-K group size follows the quantization group size of the weight tensor; both symmetric and asymmetric quantization are supported.

<sup>2</sup>: per-batch dynamic dequantization of the output tensor.
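The per-K group-wise dynamic quantization described above can be sketched in a few lines of numpy. This is an illustration of the scheme, not the library's actual kernel code; the function names and the choice of a symmetric int8 mapping are assumptions for the example.

```python
import numpy as np

def quantize_per_group_sym(x: np.ndarray, group_size: int = 128):
    """Symmetric per-batch, per-K group-wise dynamic int8 quantization (sketch).

    x has shape (batch, K); each row is split into groups of `group_size`
    along K, and every group gets its own dynamically computed scale.
    """
    batch, k = x.shape
    assert k % group_size == 0, "K must be a multiple of the group size"
    groups = x.reshape(batch, k // group_size, group_size)
    # one scale per (row, group): map the group's max magnitude to the int8 range
    scales = np.abs(groups).max(axis=-1, keepdims=True) / 127.0
    scales = np.where(scales == 0, 1.0, scales)  # guard against all-zero groups
    q = np.clip(np.round(groups / scales), -128, 127).astype(np.int8)
    return q, scales

def dequantize_per_group(q: np.ndarray, scales: np.ndarray) -> np.ndarray:
    """Per-batch dynamic dequantization back to float32."""
    return (q.astype(np.float32) * scales).reshape(q.shape[0], -1)

x = np.random.randn(2, 256).astype(np.float32)
q, s = quantize_per_group_sym(x, group_size=128)
x_hat = dequantize_per_group(q, s)
print(np.abs(x - x_hat).max())  # round-trip error is bounded by half a scale step
```

With group_size=128 every 128 consecutive K elements share one scale; group size -1 (per-channel) corresponds to a single group spanning the whole K dimension.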
dtype | algo | group size |
---|---|---|
int4 | symmetric int8 truncated quant<sup>2</sup> | multiplier of 8, -1<sup>1</sup> |
int4 | symmetric int4 full range<sup>3</sup> | multiplier of 8, -1<sup>1</sup> |
int4 | asymmetric int4 full range<sup>3</sup> | multiplier of 8, -1<sup>1</sup> |
int8 | symmetric | multiplier of 8, -1<sup>1</sup> |
fp4 | | multiplier of 8 |
nf4 | | multiplier of 8 |
<sup>1</sup>: group size = -1 means per-channel quantization on the output channel (i.e., the group size equals the input channel size).

<sup>2</sup>: truncated quant keeps only the high 4 bits of the int8 quantization result for model saving and computation.

<sup>3</sup>: full range is a quantization method that also utilizes the -8 value of the int4 range, compared with the normal int4 range [-7, 7].
NOTE: AMX_INT8 requires the group size to be aligned to 128 for best hardware efficiency.
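The two int4 algorithms above differ only in how the 4-bit codes are produced. A minimal numpy sketch under the definitions in the footnotes (function names are illustrative, not the library's API):

```python
import numpy as np

def quant_int4_truncated(x: np.ndarray):
    """Symmetric int8 quantization, then keep only the high 4 bits (sketch)."""
    scale = np.abs(x).max() / 127.0
    q8 = np.round(x / scale).astype(np.int8)   # values in [-127, 127]
    q4 = q8 >> 4                               # arithmetic shift: codes in [-8, 7]
    return q4, scale * 16                      # dequant scale absorbs the shift

def quant_int4_fullrange(x: np.ndarray):
    """Symmetric int4 that also uses the -8 code of the [-8, 7] range (sketch)."""
    scale = np.abs(x).max() / 8.0              # max magnitude maps onto 8
    q4 = np.clip(np.round(x / scale), -8, 7).astype(np.int8)
    return q4, scale

rng = np.random.default_rng(0)
x = rng.standard_normal(256).astype(np.float32)
q_t, s_t = quant_int4_truncated(x)
q_f, s_f = quant_int4_fullrange(x)
err_trunc = np.abs(x - q_t.astype(np.float32) * s_t).max()
err_full = np.abs(x - q_f.astype(np.float32) * s_f).max()
```

Both schemes store 4-bit codes in [-8, 7]; the truncated variant reuses the int8 quantization pipeline, while full range computes the scale directly against the int4 range.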
We support three kinds of kernel fusion for transformer models: QKV, MHA (multi-head attention), and FFN (feed-forward network) fusion.
fusion type | models | runtime ISA |
---|---|---|
QKV | GPT-J, LLaMA | AMX_INT8, AVX512_VNNI, AVX_VNNI |
FFN | GPT-J, LLaMA, BLOOM, ChatGLM, Falcon, MPT | AMX_INT8, AVX512_VNNI, AVX512F, AMX_BF16, AVX_VNNI, AVX2 |
MHA | see the fused-attention doc for details | |
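The idea behind QKV fusion can be shown with plain numpy: instead of three separate GEMMs for the query, key, and value projections, concatenate the three weight matrices along the output dimension and run one larger GEMM. The shapes below are illustrative, not the kernels' actual implementation:

```python
import numpy as np

seq, hidden = 8, 64  # hypothetical sequence length and hidden size
rng = np.random.default_rng(0)
x = rng.standard_normal((seq, hidden)).astype(np.float32)
wq, wk, wv = (rng.standard_normal((hidden, hidden)).astype(np.float32)
              for _ in range(3))

# unfused: three separate projection GEMMs
q, k, v = x @ wq, x @ wk, x @ wv

# fused: one GEMM against the concatenated weight, then split the output
w_qkv = np.concatenate([wq, wk, wv], axis=1)   # (hidden, 3*hidden)
q2, k2, v2 = np.split(x @ w_qkv, 3, axis=1)

assert np.allclose(q, q2) and np.allclose(k, k2) and np.allclose(v, v2)
```

A single large GEMM amortizes the cost of reloading the activation tile across the three projections, which is why the fusion pays off on the ISAs listed above.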
codename | weight config | runtime ISA |
---|---|---|
Sapphire Rapids | any int4, group size=-1, compute type=int8 | AMX_INT8 |
Ice Lake, Cascade Lake, Cooper Lake, Tiger Lake, Rocket Lake | any int4, group size=-1, compute type=int8 | AVX512_VNNI |
Skylake | any 4bits, group size=-1, compute type=fp32 | AVX512F |
Alder Lake (12th Gen), Raptor Lake (13th and 14th Gen) | any 4bits, group size=-1, compute type=int8 | AVX_VNNI |
Older architectures (before 12th Gen) | any 4bits, group size=-1, compute type=fp32 | AVX2 |
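Selecting the runtime ISA from the table above amounts to checking CPU feature flags in priority order. A minimal sketch, assuming the Linux `/proc/cpuinfo` flag names (`amx_int8`, `avx512_vnni`, `avx_vnni`, `avx512f`, `avx2`); the function names and the exact priority order are assumptions for illustration:

```python
def detect_best_isa(flags: set) -> str:
    """Pick the fastest supported ISA from a set of CPU feature flags (sketch)."""
    priority = [
        ("amx_int8", "AMX_INT8"),       # Sapphire Rapids
        ("avx512_vnni", "AVX512_VNNI"), # Ice Lake .. Rocket Lake
        ("avx_vnni", "AVX_VNNI"),       # Alder Lake / Raptor Lake
        ("avx512f", "AVX512F"),         # Skylake
        ("avx2", "AVX2"),               # older architectures
    ]
    for flag, isa in priority:
        if flag in flags:
            return isa
    return "reference"  # no vector ISA detected

def cpu_flags(path: str = "/proc/cpuinfo") -> set:
    """Parse the feature-flag line of /proc/cpuinfo into a set (Linux only)."""
    with open(path) as f:
        for line in f:
            if line.startswith("flags"):
                return set(line.split(":", 1)[1].split())
    return set()

print(detect_best_isa({"avx2", "avx512f", "avx512_vnni"}))  # AVX512_VNNI
```

On an Ice Lake machine, `detect_best_isa(cpu_flags())` would land on AVX512_VNNI, matching the table's recommended configuration.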