Enable FP8 Triton dequantized block-wise kernel
Summary: Enable the FP8 Triton block-wise dequantization kernel, which is required to upcast the outputs of a block-wise quantized all2all.

Differential Revision: D70872110
jiawenliu64 authored and facebook-github-bot committed Mar 10, 2025
1 parent 27724d9 commit f217899
Showing 2 changed files with 85 additions and 0 deletions.
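To make the summary concrete, the sketch below shows roughly how such a kernel can upcast the output of a block-wise quantized all2all. It is illustrative only and not part of this commit: it assumes an already-initialized torch.distributed process group, equal splits whose row counts are multiples of the block size, and that the FP8 payload is exchanged as raw uint8 bytes (collective support for FP8 dtypes varies); blockwise_fp8_all2all_upcast is a hypothetical helper name.

import torch
import torch.distributed as dist

from fbgemm_gpu.experimental.gemm.triton_gemm.fp8_gemm import (
    dequantize_fp8_block,
    quantize_fp8_block,
)


def blockwise_fp8_all2all_upcast(x: torch.Tensor, block: int = 128) -> torch.Tensor:
    # Quantize locally: one FP8 value per element, one FP32 scale per (block x block) tile.
    xq, x_scale = quantize_fp8_block(x, block_m=block, block_k=block)

    # Exchange the FP8 payload as raw bytes, since collectives may not accept FP8 dtypes.
    payload = xq.view(torch.uint8)
    payload_out = torch.empty_like(payload)
    dist.all_to_all_single(payload_out, payload)
    xq_out = payload_out.view(xq.dtype)

    # Exchange the per-block scales the same way (FP32 is supported directly).
    scale_out = torch.empty_like(x_scale)
    dist.all_to_all_single(scale_out, x_scale.contiguous())

    # Upcast the received FP8 blocks back to BF16 with the new kernel.
    return dequantize_fp8_block(xq_out, scale_out, block_size=block)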
26 changes: 26 additions & 0 deletions fbgemm_gpu/experimental/gemm/test/fp8_gemm_test.py
@@ -14,6 +14,7 @@

if torch.cuda.is_available():
    from fbgemm_gpu.experimental.gemm.triton_gemm.fp8_gemm import (
        dequantize_fp8_block,
        matmul_fp8_block,
        matmul_fp8_row,
        quantize_fp8_block,
@@ -274,6 +275,31 @@ def _test_quantize_fp8_block(
        _test_quantize_fp8_block((3, 6), (2, 8))
        _test_quantize_fp8_block((3, 6), (2, 8), use_scale_ub=True)

    def test_dequantize_fp8_block(self) -> None:
        def _test_dequantize_fp8_block(
            shape: Tuple[int, int],
            block_size: int,
            use_scale_ub: bool = False,
        ) -> None:
            M, K = shape
            a = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")

            scale_ub = (
                torch.tensor([1200], dtype=torch.float, device="cuda")
                if use_scale_ub
                else None
            )

            a_fp8, a_scale = quantize_fp8_block(
                a, block_m=block_size, block_k=block_size, scale_ub=scale_ub
            )
            a_dequant = dequantize_fp8_block(a_fp8, a_scale, block_size=block_size)
            self.assertTrue(torch.allclose(a, a_dequant, atol=2e-1, rtol=5e-2))

        _test_dequantize_fp8_block((3, 1024), 128)
        _test_dequantize_fp8_block((11, 128), 256)
        _test_dequantize_fp8_block((11, 256), 256, use_scale_ub=True)

    def test_matmul_fp8_block(self) -> None:
        def _test_matmul_fp8_block(
            shape: Tuple[int, int, int],
59 changes: 59 additions & 0 deletions fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py
@@ -3097,3 +3097,62 @@ def _kernel_matmul_fp8_row_non_persistent(
        tl.store(C, acc, mask=mask)
    else:
        tl.atomic_add(C, acc, mask=mask)


@triton.jit
def _kernel_dequantize_fp8_block(
    xq_ptr, x_scale_ptr, x_dequant_ptr, M, N, BLOCK_SIZE: tl.constexpr
):
    """
    Dequantize FP8 tensor to BF16 tensor.
    Args:
        xq_ptr (tl.constexpr): Pointer to FP8 tensor.
        x_scale_ptr (tl.constexpr): Pointer to FP8 scale tensor.
        x_dequant_ptr (tl.constexpr): Pointer to BF16 tensor.
        M (tl.constexpr): M dimension of input tensor.
        N (tl.constexpr): N dimension of input tensor.
        BLOCK_SIZE (tl.constexpr): Block size for reduction.
    """
    pid_m = tl.program_id(axis=0)
    pid_n = tl.program_id(axis=1)
    n = tl.cdiv(N, BLOCK_SIZE)
    offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    offs_n = pid_n * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    offs = offs_m[:, None] * N + offs_n[None, :]
    mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
    xq = tl.load(xq_ptr + offs, mask=mask).to(tl.bfloat16)
    x_scale = tl.load(x_scale_ptr + pid_m * n + pid_n)
    x_dequant = xq * x_scale
    tl.store(x_dequant_ptr + offs, x_dequant, mask=mask)


def dequantize_fp8_block(
    xq: torch.Tensor, x_scale: torch.Tensor, block_size: int = 128
) -> torch.Tensor:
    """
    Dequantize FP8 tensor to BF16 tensor.
    Args:
        xq (torch.Tensor): FP8 tensor to be dequantized.
        x_scale (torch.Tensor): FP8 scale tensor.
        block_size (int): Block size for dequantization.
    Returns:
        torch.Tensor: Dequantized BF16 tensor.
    """

    assert (
        xq.is_contiguous() and x_scale.is_contiguous()
    ), "Input tensors must be contiguous"
    assert xq.dim() == 2 and x_scale.dim() == 2, "Input tensors must have 2 dimensions"
    M, N = xq.size()
    x_dequant = torch.empty_like(xq, dtype=torch.bfloat16)
    grid = lambda meta: (
        triton.cdiv(M, meta["BLOCK_SIZE"]),
        triton.cdiv(N, meta["BLOCK_SIZE"]),
    )
    _kernel_dequantize_fp8_block[grid](
        xq, x_scale, x_dequant, M, N, BLOCK_SIZE=block_size  # pyre-ignore[6]
    )
    return x_dequant
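For reference, a minimal round-trip sketch of the new function, under the same assumptions as the test above (a CUDA device and the FBGEMM Triton FP8 helpers); the tensor sizes are arbitrary:

import torch

from fbgemm_gpu.experimental.gemm.triton_gemm.fp8_gemm import (
    dequantize_fp8_block,
    quantize_fp8_block,
)

x = torch.randn(512, 1024, dtype=torch.bfloat16, device="cuda")
xq, x_scale = quantize_fp8_block(x, block_m=128, block_k=128)
# Per the kernel's scale indexing (pid_m * cdiv(N, BLOCK_SIZE) + pid_n), x_scale is
# expected to hold one scale per 128 x 128 tile in row-major order, i.e. shape (4, 8) here.
x_bf16 = dequantize_fp8_block(xq, x_scale, block_size=128)
assert x_bf16.dtype == torch.bfloat16 and x_bf16.shape == x.shape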
