integrate blockwise fp8 kernel (#3529)
yizhang2077 authored Feb 12, 2025
1 parent 4430c0a · commit 98eecbd
Showing 3 changed files with 124 additions and 23 deletions.
2 changes: 1 addition & 1 deletion python/pyproject.toml
@@ -25,7 +25,7 @@ runtime_common = [
]
srt = [
"sglang[runtime_common]", "cuda-python",
"sgl-kernel>=0.0.3.post4", "torch", "vllm>=0.6.4.post1,<=0.7.2",
"sgl-kernel>=0.0.3.post5", "torch", "vllm>=0.6.4.post1,<=0.7.2",
"flashinfer_python>=0.2.0.post2", "outlines>=0.0.44,<=0.1.11"
]

108 changes: 90 additions & 18 deletions python/sglang/srt/layers/quantization/fp8_kernel.py
@@ -76,11 +76,60 @@ def _per_token_group_quant_fp8(
tl.store(y_s_ptr, y_s)


@triton.jit
def _per_token_group_quant_fp8_colmajor(
# Pointers to inputs and output
y_ptr,
y_q_ptr,
y_s_ptr,
group_size,
# Num columns of y
y_num_columns,
# Stride from one column to the next of y_s
y_s_col_stride,
# Avoid division by zero
eps,
# Information for float8
fp8_min,
fp8_max,
# Meta-parameters
BLOCK: tl.constexpr,
):
"""A Triton-accelerated function to perform per-token-group
quantization on a tensor, writing the per-group scales in
column-major order.
This function converts the tensor values into float8 values.
"""
# Map the program id to the group of y it should quantize.
g_id = tl.program_id(0)
y_ptr += g_id * group_size
y_q_ptr += g_id * group_size

# Convert g_id, the flattened block coordinate, to 2D so we can index
# into the output y_scales matrix
blocks_per_row = y_num_columns // group_size
scale_col = g_id % blocks_per_row
scale_row = g_id // blocks_per_row
y_s_ptr += scale_col * y_s_col_stride + scale_row

cols = tl.arange(0, BLOCK) # group_size <= BLOCK
mask = cols < group_size

y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32)
# Quant
_absmax = tl.maximum(tl.max(tl.abs(y)), eps)
y_s = _absmax / fp8_max
y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)

tl.store(y_q_ptr + cols, y_q, mask=mask)
tl.store(y_s_ptr, y_s)
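
For intuition, the scale-pointer arithmetic above can be restated in plain Python: the flat group id over a row-major (num_tokens, y_num_columns) activation is first mapped to a (row, column) position in the scale matrix, then to a linear offset in its column-major storage. The helper below is illustrative only, not part of the kernel; for the allocation used later in this file, y_s_col_stride equals the number of token rows.

def colmajor_scale_offset(g_id, y_num_columns, group_size, y_s_col_stride):
    # Groups are numbered row by row over a (num_tokens, y_num_columns) tensor.
    blocks_per_row = y_num_columns // group_size
    scale_col = g_id % blocks_per_row   # which group inside the token row
    scale_row = g_id // blocks_per_row  # which token row
    # Column-major storage: rows within one column are contiguous, so a
    # column step moves by y_s_col_stride (= num_tokens) elements.
    return scale_col * y_s_col_stride + scale_row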


def per_token_group_quant_fp8(
x: torch.Tensor,
group_size: int,
eps: float = 1e-10,
dtype: torch.dtype = fp8_type_,
column_major_scales: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Function to perform per-token-group quantization on an input tensor `x`.
@@ -112,29 +161,52 @@ def per_token_group_quant_fp8(
x_q = torch.empty_like(x, device=x.device, dtype=dtype)
M = x.numel() // group_size
N = group_size
x_s = torch.empty(
x.shape[:-1] + (x.shape[-1] // group_size,),
device=x.device,
dtype=torch.float32,
)
if column_major_scales:
x_s = torch.empty(
(x.shape[-1] // group_size,) + x.shape[:-1],
device=x.device,
dtype=torch.float32,
).permute(-1, -2)
else:
x_s = torch.empty(
x.shape[:-1] + (x.shape[-1] // group_size,),
device=x.device,
dtype=torch.float32,
)

BLOCK = triton.next_power_of_2(N)
# heuristics for number of warps
num_warps = min(max(BLOCK // 256, 1), 8)
num_stages = 1
_per_token_group_quant_fp8[(M,)](
x,
x_q,
x_s,
group_size,
N,
eps,
fp8_min=fp8_min,
fp8_max=fp8_max,
BLOCK=BLOCK,
num_warps=num_warps,
num_stages=num_stages,
)
if column_major_scales:
_per_token_group_quant_fp8_colmajor[(M,)](
x,
x_q,
x_s,
group_size,
x.shape[1],
x_s.stride(1),
eps,
fp8_min=fp8_min,
fp8_max=fp8_max,
BLOCK=BLOCK,
num_warps=num_warps,
num_stages=num_stages,
)
else:
_per_token_group_quant_fp8[(M,)](
x,
x_q,
x_s,
group_size,
N,
eps,
fp8_min=fp8_min,
fp8_max=fp8_max,
BLOCK=BLOCK,
num_warps=num_warps,
num_stages=num_stages,
)

return x_q, x_s
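
As a cross-check of what the two kernels above compute, here is a minimal pure-PyTorch sketch of the same per-token-group quantization. It assumes a 2D, contiguous input whose last dimension is divisible by group_size; the function name and the default torch.float8_e4m3fn dtype are illustrative, not part of this module.

import torch

def per_token_group_quant_fp8_ref(x, group_size, eps=1e-10,
                                  dtype=torch.float8_e4m3fn,
                                  column_major_scales=False):
    # Slow reference: one scale per contiguous group of `group_size`
    # elements along the last dimension.
    finfo = torch.finfo(dtype)
    groups = x.reshape(-1, group_size).to(torch.float32)
    absmax = groups.abs().amax(dim=-1, keepdim=True).clamp(min=eps)
    scales = absmax / finfo.max
    q = (groups / scales).clamp(finfo.min, finfo.max).to(dtype)
    x_q = q.reshape(x.shape)
    x_s = scales.reshape(x.shape[0], x.shape[1] // group_size)
    if column_major_scales:
        # Same values, stored column-major as the CUTLASS path expects.
        x_s = x_s.t().contiguous().t()
    return x_q, x_s

For an input of shape (num_tokens, hidden_size) with group_size=128, x_q keeps the input shape and x_s holds one float32 scale per token per 128-wide group.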

37 changes: 33 additions & 4 deletions python/sglang/srt/layers/quantization/fp8_utils.py
@@ -10,6 +10,9 @@
from sglang.srt.utils import is_hip

is_hip_ = is_hip()
_is_cuda = torch.cuda.is_available() and torch.version.cuda
if _is_cuda:
from sgl_kernel import fp8_blockwise_scaled_mm


def normalize_e4m3fn_to_e4m3fnuz(
@@ -36,6 +39,19 @@ def normalize_e4m3fn_to_e4m3fnuz(
return weight, weight_scale, input_scale


def cutlass_block_fp8_supported() -> bool:
if _is_cuda:
major, minor = torch.cuda.get_device_capability()
sm_version = major * 10 + minor
cuda_version = tuple(map(int, torch.version.cuda.split(".")))
if cuda_version >= (12, 0) and sm_version >= 90:
return True
return False


CUTLASS_BLOCK_FP8_SUPPORTED = cutlass_block_fp8_supported()
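
A small worked example of this gate (the helper is hypothetical and only mirrors the check; the compute capabilities are the commonly published values for these GPUs):

def _would_use_cutlass(major: int, minor: int, cuda_version=(12, 4)) -> bool:
    # Mirrors cutlass_block_fp8_supported() for a hypothetical device/toolkit.
    sm_version = major * 10 + minor
    return cuda_version >= (12, 0) and sm_version >= 90

assert _would_use_cutlass(9, 0)      # H100, SM90: CUTLASS blockwise GEMM eligible
assert not _would_use_cutlass(8, 9)  # RTX 4090, SM89: Triton fallback
assert not _would_use_cutlass(8, 0)  # A100, SM80: Triton fallback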


def apply_w8a8_block_fp8_linear(
input: torch.Tensor,
weight: torch.Tensor,
@@ -48,11 +64,24 @@ def apply_w8a8_block_fp8_linear(
# View input as 2D matrix for fp8 methods
input_2d = input.view(-1, input.shape[-1])
output_shape = [*input.shape[:-1], weight.shape[0]]

q_input, x_scale = per_token_group_quant_fp8(input_2d, block_size[1])
output = w8a8_block_fp8_matmul(
q_input, weight, x_scale, weight_scale, block_size, output_dtype=input.dtype
# TODO: add more robust shape check here
shape_supported_by_cutlass = (
weight.shape[0] % 128 == 0 and weight.shape[1] % 128 == 0
)
if CUTLASS_BLOCK_FP8_SUPPORTED and shape_supported_by_cutlass:
q_input, x_scale = per_token_group_quant_fp8(
input_2d, block_size[1], column_major_scales=True
)
output = fp8_blockwise_scaled_mm(
q_input, weight.T, x_scale, weight_scale.T, out_dtype=input.dtype
)
else:
q_input, x_scale = per_token_group_quant_fp8(
input_2d, block_size[1], column_major_scales=False
)
output = w8a8_block_fp8_matmul(
q_input, weight, x_scale, weight_scale, block_size, output_dtype=input.dtype
)

if bias is not None:
output = output + bias
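
For completeness, the math both GEMM paths implement can be written as a slow dequantize-then-matmul reference. The sketch below is illustrative only (not the library kernel) and assumes weight is (N, K) in FP8 with one float32 scale per (block_n x block_k) tile, q_input is (M, K) in FP8 with one scale per block_k-wide group, and that N and K are multiples of the block sizes.

import torch

def blockwise_fp8_matmul_ref(q_input, weight, x_scale, weight_scale,
                             block_size, output_dtype=torch.bfloat16):
    # Dequantize both operands to float32, then do a plain matmul.
    block_n, block_k = block_size
    w = weight.to(torch.float32) * weight_scale.repeat_interleave(
        block_n, dim=0
    ).repeat_interleave(block_k, dim=1)
    a = q_input.to(torch.float32) * x_scale.repeat_interleave(block_k, dim=1)
    return (a @ w.t()).to(output_dtype)

The CUTLASS branch above hands over weight.T and weight_scale.T to match that kernel's expected operand layout, while the Triton branch consumes the row-major (N, K) weight and its block scales directly; both compute the product sketched here.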
