
mobiusml/gemlite


GemLite

Triton Kernels for Efficient Low-Bit Matrix Multiplication


Made with ❤ by the team at Mobius Labs for the 'Aana' (ആന : Elephant) suite of multimodal products.

GemLite is a collection of Triton kernels designed for efficient low-bit matrix multiplication, emphasizing simplicity and reusability. It delivers up to 7-8x faster prefill and 3-6x faster decoding compared to the default torchao kernels. For detailed benchmarks, see the Performance section.

GemLite aims to balance flexibility and performance: the codebase is easy to use and modify, so users can develop high-performance kernels optimized for their specific hardware. We include multiple versions of the kernels to maximize performance across different matrix shapes.

The project started with CUDA kernels, but we have switched to Triton for enhanced flexibility. For the old CUDA version, please refer to this branch.

Result Teaser

[Figures: End-to-end performance (Llama3 8-bit) and matmul performance (A16W8)]

Extensive performance results across different bitwidths, batch sizes, and devices are available in the Performance section below.


Recent Highlights

  • Major performance improvements, especially on the A100 and H100.
  • Flexible bitpacking: use 8-bit packing for improved batched performance on the A100 and H100 with packed data.
  • Autotune caching: save/load the best autotune configs across all the kernels with a single line of code.
  • Helper functions: helper functions make it easier to get started, especially useful for dynamic quantization.
  • New GEMV RevSplitK algorithm: outperforms GEMM Split-K and GEMV for batch-size=1 with packed data.
  • Channel-wise scaling: support for channel-wise scaling of weights, activations, or both.
  • Precision support: Includes FP16 x Wn, FP8 x FP8, FP8 x Wn, INT8 x INT8 and INT8 x Wn.
  • torch.compile() support.
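Bitpacking stores several low-bit weights in one wider integer. The following is a minimal, generic sketch in plain Python, not GemLite's actual memory layout: `pack_bits` and `unpack_bits` are hypothetical helpers that pack unsigned n-bit values into 8-bit containers (as in the 8-bit packing mentioned above) or, as an assumed alternative, 32-bit containers.

```python
def pack_bits(values, nbits, container_bits=8):
    """Pack unsigned nbits-wide integers into container_bits-wide integers.

    Value i within a container occupies bits [i * nbits, (i + 1) * nbits).
    """
    if container_bits % nbits != 0:
        raise ValueError("container width must be a multiple of nbits")
    per_container = container_bits // nbits
    mask = (1 << nbits) - 1
    packed = []
    for start in range(0, len(values), per_container):
        word = 0
        for i, v in enumerate(values[start:start + per_container]):
            if not 0 <= v <= mask:
                raise ValueError(f"value {v} does not fit in {nbits} bits")
            word |= v << (i * nbits)  # shift each value into its own bit slot
        packed.append(word)
    return packed


def unpack_bits(packed, nbits, count, container_bits=8):
    """Inverse of pack_bits: recover the first `count` nbits-wide values."""
    per_container = container_bits // nbits
    mask = (1 << nbits) - 1
    out = []
    for word in packed:
        for i in range(per_container):
            out.append((word >> (i * nbits)) & mask)
    return out[:count]


# Eight 4-bit weights fit in four 8-bit containers or in one 32-bit container.
w = [3, 15, 0, 7, 9, 1, 12, 5]
p8 = pack_bits(w, nbits=4, container_bits=8)
p32 = pack_bits(w, nbits=4, container_bits=32)
print(len(p8), len(p32))  # 4 1
assert unpack_bits(p8, 4, len(w), 8) == w
```

A kernel reading packed weights performs the same shift-and-mask steps as `unpack_bits` before dequantizing; the container width only changes how many values each load brings in.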

Getting Started

Installation

Latest Stable Version

pip install gemlite

Latest (from the GitHub repository)

pip install git+https://github.com/mobiusml/gemlite/

Usage

from gemlite.core import DType, GemLiteLinear

#Currently using the Triton backend as the default
gemlite_linear = GemLiteLinear(
    W_nbits, #weight quantization bitwidth. supported: [8, 4, 2, 1]
    group_size=group_size, # any group_size divisible by 32 - enable autotune for group_size < 128 (!)
    in_features=in_features, # input size
    out_features=out_features, #output size
    input_dtype=DType.FP16, #FP16, FP8, INT8
    output_dtype=DType.FP16, #FP16, FP32, FP8, INT32
    scaled_activations=False, #If the activations are scaled or not
)

#Packing: we follow the same format as hqq (https://github.com/mobiusml/hqq/)
gemlite_linear.pack(W_q, scales, zeros, bias)

#For activation quantization, override this function. It should return the scaled activations and the activation scales:
#gemlite_linear.scale_activations = f(x: torch.Tensor) -> x_scaled: torch.Tensor, scales: torch.Tensor # x ~ x_scaled * scales

#Forward
out = gemlite_linear(x)
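The commented-out `scale_activations` override must return `x_scaled` and `scales` such that x ≈ x_scaled * scales. Below is a minimal sketch of that contract, assuming per-row (per-token) symmetric INT8 scaling. It uses plain Python lists for illustration only; a real override would take and return `torch.Tensor` objects in whatever shapes and dtypes the kernel expects, which this sketch does not cover. `scale_activations_int8` is a hypothetical name.

```python
def scale_activations_int8(x):
    """Per-row symmetric INT8 scaling: x[i][j] ~ x_scaled[i][j] * scales[i]."""
    x_scaled, scales = [], []
    for row in x:
        amax = max(abs(v) for v in row)
        # Map the largest magnitude in the row to 127; guard all-zero rows.
        scale = amax / 127.0 if amax > 0 else 1.0
        q = [max(-127, min(127, round(v / scale))) for v in row]
        x_scaled.append(q)
        scales.append(scale)
    return x_scaled, scales


x = [[0.5, -1.27, 0.01], [2.54, 0.0, -0.3]]
x_q, s = scale_activations_int8(x)

# Reconstruction error is at most half a quantization step in each row.
for row, q_row, scale in zip(x, x_q, s):
    for v, q in zip(row, q_row):
        assert abs(v - q * scale) <= scale / 2 + 1e-12
```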

Additionally, we offer helper functions that operate as follows:

from gemlite.helper import *

#Non-packed 8-bit weights (INT8 or FP8)
gemlite_linear = A16W8(device='cuda:0').from_linear(linear_layer) #FP16 activations
gemlite_linear = A8W8_int8_dynamic(device='cuda:0').from_linear(linear_layer) #INT8 activations
gemlite_linear = A8W8_fp8_dynamic(device='cuda:0').from_linear(linear_layer) #FP8 activations

#Packed weights for 4-bit/2-bit/1-bit (HQQ format)
gemlite_linear = A16Wn(device='cuda:0').from_hqqlinear(hqqlinear_layer) #FP16 activations
gemlite_linear = A8Wn_dynamic(device='cuda:0').from_hqqlinear(hqqlinear_layer) #FP8 activations
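The packed helpers consume weights in HQQ format, where quantized weights are stored alongside per-group `scales` and `zeros`. The sketch below assumes the dequantization convention W ≈ (W_q - zeros) * scales and uses simple round-to-nearest min/max quantization, not HQQ's own optimizer. It works on a flat Python list, and `quantize_groupwise`/`dequantize_groupwise` are hypothetical helpers, not part of gemlite or hqq.

```python
def quantize_groupwise(w, nbits=4, group_size=4):
    """Round-to-nearest asymmetric quantization per group of `group_size` weights.

    Returns (w_q, scales, zeros) with w[i] ~ (w_q[i] - zeros[g]) * scales[g],
    where g = i // group_size.
    """
    if len(w) % group_size != 0:
        raise ValueError("len(w) must be divisible by group_size")
    qmax = (1 << nbits) - 1
    w_q, scales, zeros = [], [], []
    for start in range(0, len(w), group_size):
        group = w[start:start + group_size]
        lo, hi = min(group), max(group)
        scale = (hi - lo) / qmax if hi > lo else 1.0
        zero = -lo / scale  # maps lo to 0 and hi to qmax
        w_q.extend(max(0, min(qmax, round(v / scale + zero))) for v in group)
        scales.append(scale)
        zeros.append(zero)
    return w_q, scales, zeros


def dequantize_groupwise(w_q, scales, zeros, group_size=4):
    """Apply (w_q - zero) * scale using the parameters of each weight's group."""
    return [(q - zeros[i // group_size]) * scales[i // group_size]
            for i, q in enumerate(w_q)]


w = [0.1, -0.4, 0.25, 0.9, -1.0, -0.2, 0.3, 0.05]
w_q, s, z = quantize_groupwise(w, nbits=4, group_size=4)
w_r = dequantize_groupwise(w_q, s, z, group_size=4)
max_err = max(abs(a - b) for a, b in zip(w, w_r))
assert max_err <= max(s) / 2 + 1e-9
```

Smaller groups give each group its own scale and zero, which usually lowers the error at the cost of storing more metadata.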

Triton autotuning can be time-consuming. To accelerate this process, we provide tools to automatically cache and load the optimal autotuning configurations for all kernels:

GemLiteLinear.cache_config('a100_config.json') #Save: run this after the kernels have been autotuned over multiple batch sizes
GemLiteLinear.load_config('a100_config.json') #Load

Ensure that you have one JSON cache file per GPU model. When the cache is loaded, the kernels will skip autotuning, leading to a faster startup time.

Deep Dive

We implement various versions of the Triton kernels:

  • GEMV: This GEMV kernel splits the activations into 1D chunks, performs the dot product using tl.sum, and accumulates via atomic addition. It is primarily intended for use with small batch sizes (M < 16). As tl.atomic_add does not support bfloat16, this kernel is limited to float16.

  • GEMM: This GEMM kernel is implemented similarly to GPTQ-triton. Since it uses tensor cores, activations must be padded with zeros along the batch dimension to fit at least 16 rows. It supports both float32 and float16 accumulation for fp16 inputs, but only float32 accumulation for bfloat16.

  • GEMM Split-K: This Split-K GEMM kernel is implemented similarly to the GPTQ Split-K version. We build on the GEMM kernel above and add another grid dimension that splits the K dimension into multiple jobs; each job computes a partial sum, and the partial sums are atomically added into the output. Split-K performs particularly well for batched LLM decoding (batch sizes between 1 and 32).

  • GEMV RevSplit-K: This algorithm, newly proposed in GemLite, works in the opposite direction to the GEMM Split-K approach, but within a GEMV kernel. By doubling the workload per Triton program launched in the GEMV kernel, it reduces how often scales/zeros are loaded and lowers the number of threads needed. As a result, it delivers the best performance for batch-size=1 decoding.
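The Split-K idea can be checked with a small sequential model: split the reduction dimension K into `split_k` jobs, let each job compute partial sums over its slice, and add them into a zero-initialized output. This is a sketch under simplifying assumptions: the loop stands in for concurrent Triton programs, `+=` stands in for `tl.atomic_add`, and tiling, tensor cores, and dequantization are ignored. `matmul_split_k` and `matmul_reference` are hypothetical names.

```python
def matmul_reference(X, W):
    """Plain Y = X @ W for X of shape M x K and W of shape K x N."""
    K, N = len(W), len(W[0])
    return [[sum(row[k] * W[k][n] for k in range(K)) for n in range(N)] for row in X]


def matmul_split_k(X, W, split_k):
    """Model of Split-K: split the K (reduction) dimension into `split_k` jobs.

    Each job produces a partial sum for every output element over its K slice;
    the partial sums are added into Y, standing in for tl.atomic_add on the GPU.
    """
    M, K, N = len(X), len(W), len(W[0])
    if K % split_k != 0:
        raise ValueError("K must be divisible by split_k in this sketch")
    chunk = K // split_k
    Y = [[0.0] * N for _ in range(M)]  # output must start at zero
    for pid_k in range(split_k):  # on the GPU these jobs run concurrently
        k0, k1 = pid_k * chunk, (pid_k + 1) * chunk
        for m in range(M):
            for n in range(N):
                partial = sum(X[m][k] * W[k][n] for k in range(k0, k1))
                Y[m][n] += partial  # accumulate this job's partial sum
    return Y


X = [[1.0, 2.0, 3.0, 4.0], [0.5, -1.0, 2.0, 0.0]]  # M=2, K=4
W = [[1.0, 0.0], [0.0, 1.0], [2.0, 1.0], [-1.0, 3.0]]  # K=4, N=2
ref = matmul_reference(X, W)
for s in (1, 2, 4):
    Y = matmul_split_k(X, W, split_k=s)
    assert all(abs(a - b) < 1e-9 for ra, rb in zip(Y, ref) for a, b in zip(ra, rb))
```

On a GPU the order of atomic additions is not fixed, so Split-K results can differ from a single-pass reduction in the last bits of floating-point precision.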

All kernels are flexible, supporting 8, 4, 2, and 1-bit weight precisions as well as both fp16 and int8/fp8 activations.

Performance

End-to-End Performance

We present various end-to-end Llama results generated with gpt-fast. GemLite delivers up to 7-8x faster prefill and 3-6x faster decoding compared to the default torchao kernels:

[Figures: Llama3 8-bit, Llama3 8-bit dynamic, Llama3 4-bit, Llama2 prefill]

Matmul Performance

We present performance results across various batch sizes on the RTX 4090. Performance is measured as the speed-up relative to A16W16 (fp16 torch.matmul). You can reproduce these results by running examples/benchmark_triton.py after installing the necessary dependencies via install_dependencies.sh.

8-bit Weights

[Figures: RTX 4090, group_size = in_features; matrix shapes 4096x4096, 8192x8192, 16384x16384, 32768x32768]

4-bit Weights

[Figures: RTX 4090, group_size = 128; matrix shapes 4096x4096, 8192x8192, 16384x16384, 32768x32768]

2-bit Weights

[Figures: RTX 4090, group_size = 128; matrix shapes 4096x4096, 8192x8192, 16384x16384, 32768x32768]

Talks and Resources

Check out the talk that lead author Dr. Hicham Badri gave about GemLite at GPU MODE. You can also find the slides here.

Please note that GemLite is under active development, and the content discussed in the talk may evolve as the library continues to improve.

Contributing

Contributions are always welcome! Please feel free to raise issues, submit pull requests, or start a discussion.

If you're looking to integrate GemLite with major inference and AI libraries, we'd love to hear about it!