
[QST] Is cutlass::bfloat16_t x cutlass::int2b_t GEMM possible? #1915

Open
areddy2022 opened this issue Nov 3, 2024 · 3 comments

@areddy2022

While looking at example 55 (cutlass/examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_bf16_gemm.cu), I was curious whether this modification would be legal:

From:

```cpp
using MmaType = cutlass::bfloat16_t;
using QuantType = cutlass::int4b_t;
```

To:

```cpp
using MmaType = cutlass::bfloat16_t;
using QuantType = cutlass::int2b_t;
```

According to the example's README.md, "For 8-bit x 4-bit or 2-bit, both inputs must be K-major." However, the internal comment states, "Only supports INT4 x { FP16, BF16 }." Furthermore, I'm having trouble finding documentation in the library on using the int2b_t datatype in GEMM. I apologize if this question should be more detailed or if I missed some part of the documentation.

@IwakuraRein
Contributor

That comment is inaccurate and needs updating. This kernel always converts the smaller dtype to the MMA dtype using CUDA cores, so as long as such conversion logic exists, the combination is legal. I just checked that INT2 x BF16 compiles successfully and passes the test.

However, for optimal performance, this conversion should be low-cost; otherwise, it may negate the benefits of using mixed dtypes. Please refer to the PTX documentation to verify whether the hardware natively supports the conversion, and to this header to see whether an optimized software solution is available.

Currently, only naive conversion logic exists for int2b_t, i.e., casting first to int and then to bfloat16_t. This makes the kernel significantly slower than INT4 x BF16.
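
For reference, the naive path is roughly equivalent to the following standalone sketch (plain C++ rather than the actual CUTLASS code; `sign_extend_int2` is my own illustrative helper, and `float` stands in for `bfloat16_t`):

```cpp
#include <cstdio>

// Sketch of the naive per-element path: sign-extend the 2-bit field to int,
// then convert the int to a floating-point type (float stands in for
// bfloat16_t here). One scalar conversion per element is what makes this
// slow compared to a vectorized bit-manipulation routine.
int sign_extend_int2(unsigned bits) {
  return (bits & 0x2u) ? static_cast<int>(bits | ~0x3u) : static_cast<int>(bits);
}

int main() {
  for (unsigned b = 0; b < 4; ++b) {
    std::printf("bits=%u -> %f\n", b, static_cast<double>(sign_extend_int2(b)));
  }
  return 0;
}
```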

Please let us know if this combination is important for your use case, and we'll see whether we can implement an optimized conversion for it.

@areddy2022
Author

This would be a beneficial feature for extremely low-bit quantization (e.g., BitNet or HQQ).

If I were to try to implement the conversion myself, would I need to add the numeric conversion code in numeric_conversion.h?

@IwakuraRein
Contributor

Yes. Feel free to implement it as a partial specialization of NumericArrayConverter. Please refer to the INT4 => BF16 specialization as an example.
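
A structural sketch of what such a specialization might look like is below. This is a hypothetical skeleton, not working library code: the template parameter list and the `result_type` / `source_type` / `convert` / `operator()` members are assumed to mirror the existing specializations in numeric_conversion.h, N = 8 is chosen arbitrarily, and the loop body is the naive path that a real implementation would replace with a vectorized bit trick.

```cpp
#include "cutlass/array.h"
#include "cutlass/numeric_types.h"
#include "cutlass/numeric_conversion.h"

namespace cutlass {

// Hypothetical skeleton, modeled on the existing INT4 => BF16 specialization.
template <FloatRoundStyle Round>
struct NumericArrayConverter<bfloat16_t, int2b_t, 8, Round> {
  using result_type = Array<bfloat16_t, 8>;
  using source_type = Array<int2b_t, 8>;

  CUTLASS_HOST_DEVICE
  static result_type convert(source_type const& source) {
    result_type result;
    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < 8; ++i) {
      // Naive element-wise path for clarity; an optimized version would
      // operate on the packed source bits with bit manipulation instead.
      int2b_t v = source[i];
      result[i] = bfloat16_t(static_cast<float>(static_cast<int>(v)));
    }
    return result;
  }

  CUTLASS_HOST_DEVICE
  result_type operator()(source_type const& s) const { return convert(s); }
};

} // namespace cutlass
```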

IMO, INT2 => BF16 would be essentially the same as INT4 => BF16: just move the INT2 bits into the LSBs of the target BF16 (as part of the mantissa).
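
For intuition, here is a standalone sketch of that mantissa trick (plain C++, not CUTLASS code; the constants are my own illustration and assume bf16's 1-8-7 sign/exponent/mantissa layout). Biasing the signed 2-bit value v in [-2, 1] by +2 gives an unsigned payload in [0, 3]; OR-ing that payload into the mantissa LSBs of the bf16 bit pattern 0x4300 (which encodes 128.0, so each mantissa LSB has weight 1) produces 128 + (v + 2), and subtracting 130.0 recovers v exactly:

```cpp
#include <cstdint>
#include <cstdio>
#include <cstring>

// Interpret a bf16 bit pattern as a float (bf16 is the upper 16 bits of fp32).
float bf16_bits_to_float(uint16_t bits) {
  uint32_t u = static_cast<uint32_t>(bits) << 16;
  float f;
  std::memcpy(&f, &u, sizeof(f));
  return f;
}

int main() {
  for (int v = -2; v <= 1; ++v) {
    uint16_t payload = static_cast<uint16_t>(v + 2);          // biased to [0, 3]
    uint16_t bits = static_cast<uint16_t>(0x4300 | payload);  // bf16 for 128 + payload
    float recovered = bf16_bits_to_float(bits) - 130.0f;      // remove 128 + 2 bias
    std::printf("v=%d -> %.1f\n", v, recovered);
  }
  return 0;
}
```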

Also notice that even with offline layout swizzling, int2b_t values cannot be packed into a full register when loading from smem (int4b_t can be packed to satisfy LDS.32, whereas int2b_t can only satisfy LDS.16). We are currently experimenting with improving this.
