[not for land] hook up MX to CUDA 12.8 cuBLAS MX gemm #1625

Open · wants to merge 13 commits into base: gh/vkuzo/21/head
Changes from 6 commits
127 changes: 115 additions & 12 deletions benchmarks/float8/bench_matmul.py
@@ -4,6 +4,7 @@
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.
import itertools
from enum import IntEnum
from typing import Optional

import fire
@@ -26,14 +27,44 @@
h100_peak_flops_fp16_tc = 989e12
h100_peak_tops_float8_tc = 1979e12

-dtype_to_peak_tops = {
# HGX B200 specs: https://www.nvidia.com/en-us/data-center/hgx/
# note: divided numbers from ^ by 2 to undo the effects of sparsity
# TODO(this PR): I'm achieving 5% of peak TFLOPS with bf16 and float8,
# something seems funky
b200_peak_flops_float32 = 600e12
b200_peak_flops_fp16_tc = 18e15
b200_peak_tops_float8_tc = 36e15
b200_peak_tops_float4_tc = 72e15

dtype_to_peak_tops_h100 = {
torch.float32: h100_peak_flops_float32,
torch.float16: h100_peak_flops_fp16_tc,
torch.bfloat16: h100_peak_flops_fp16_tc,
torch.float8_e4m3fn: h100_peak_tops_float8_tc,
torch.float8_e5m2: h100_peak_tops_float8_tc,
}

dtype_to_peak_tops_b200 = {
torch.float32: b200_peak_flops_float32,
torch.float16: b200_peak_flops_fp16_tc,
torch.bfloat16: b200_peak_flops_fp16_tc,
torch.float8_e4m3fn: b200_peak_tops_float8_tc,
torch.float8_e5m2: b200_peak_tops_float8_tc,
# TODO float4
}

# TODO(this PR): switch automatically by detected hardware type
# TODO(this PR): fp4 is currently using fp8's peak tops below, fix it
dtype_to_peak_tops = dtype_to_peak_tops_b200


# not for land, matching https://www.internalfb.com/phabricator/paste/view/P1717686991
class DataType(IntEnum):
DEFAULT = 0
E8M0 = 1
FP4 = 2
UFP8 = 3


def benchmark_fn_in_sec(f, *args, **kwargs):
# Manual warmup
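
A note on the "switch automatically by detected hardware type" TODO above: a minimal sketch (not part of this diff) could key off the reported compute capability, assuming SM 9.x corresponds to H100-class and SM 10.x to B200-class parts:

import torch

def get_dtype_to_peak_tops():
    # Hypothetical mapping, not from this PR: sm90 ~= H100, sm100 ~= B200.
    major, _minor = torch.cuda.get_device_capability()
    if major >= 10:
        return dtype_to_peak_tops_b200
    return dtype_to_peak_tops_h100

dtype_to_peak_tops = get_dtype_to_peak_tops()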
@@ -75,6 +106,7 @@ def run(
N: Optional[int] = None,
use_gpu_kernel_time: bool = False,
scaling_granularity: str = "tensorwise",
blockwise_dtype: Optional[str] = None,
):
device = "cuda"

@@ -85,15 +117,17 @@ def run(
"K",
"N",
"ref_time_s",
"fp8_time_s",
"fp8_speedup",
"lowp_time_s",
"lowp_speedup",
)
results = []

dtype = torch.bfloat16
name_to_shapes = get_name_to_shapes_iter(shape_gen_name, M, K, N)
fast_accum_vals = [True, False]
-scaling_granularity = ScalingGranularity(scaling_granularity)
# Note: blockwise not in enum because blockwise is in prototype
if scaling_granularity != "blockwise":
scaling_granularity = ScalingGranularity(scaling_granularity)

for idx, (fast_accum, (name, (M, K, N))) in enumerate(
itertools.product(fast_accum_vals, name_to_shapes)
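
For the blockwise branches in the next hunk, each scale element is a single E8M0 byte covering one block of elements along K (32 for the fp8 path, 16 for the fp4 path), which is why scale_a is shaped (M, K // BLOCK_SIZE). A minimal decoding sketch, assuming the OCP MX convention of an exponent-only encoding with bias 127 and 0xFF reserved for NaN:

import torch

def decode_e8m0(scale_bytes: torch.Tensor) -> torch.Tensor:
    # value = 2 ** (byte - 127); 0xFF decodes to NaN (assumed OCP MX convention).
    scales = torch.exp2(scale_bytes.to(torch.float32) - 127.0)
    return torch.where(scale_bytes == 0xFF, torch.full_like(scales, float("nan")), scales)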
@@ -119,28 +153,97 @@ def run(
# raw float8 matmul (upper bound for what we can achieve in eager mode)
# TODO(future): add e5m2
d1, d2, d3 = torch.float8_e4m3fn, torch.float8_e4m3fn, dtype
-A = torch.zeros(M, K, device=device, dtype=d1)
-B = torch.zeros(K, N, device=device, dtype=d2).t().contiguous().t()
A = torch.randn(M, K, device=device).to(d1)
B = torch.randn(K, N, device=device).to(d2).t().contiguous().t()
if scaling_granularity == ScalingGranularity.TENSORWISE:
scale_a = torch.tensor([1.0], device=device)
scale_b = torch.tensor([1.0], device=device)
-else:
-assert scaling_granularity == ScalingGranularity.AXISWISE, "unsupported"
elif scaling_granularity == ScalingGranularity.AXISWISE:
scale_a = torch.ones(M, 1, device=device)
scale_b = torch.ones(1, N, device=device)
elif scaling_granularity == "blockwise" and blockwise_dtype == "float8_e4m3":
# TODO(this PR): also block size 16
BLOCK_SIZE = 32
A = torch.randint(128, (M, K), device=device, dtype=torch.uint8).view(
torch.float8_e4m3fn
)
B = (
torch.randint(128, (N, K), device=device, dtype=torch.uint8)
.view(torch.float8_e4m3fn)
.t()
)
scale_a = torch.randint(
128, (M, K // BLOCK_SIZE), dtype=torch.uint8, device="cuda"
)
scale_b = torch.randint(
128, (N, K // BLOCK_SIZE), dtype=torch.uint8, device="cuda"
).t()
elif scaling_granularity == "blockwise" and blockwise_dtype == "float4":
# TODO(this PR): also block size 16
BLOCK_SIZE = 16
A = torch.randint(128, (M, K // 2), device=device, dtype=torch.uint8).view(
torch.float8_e4m3fn
)
B = (
torch.randint(128, (N, K // 2), device=device, dtype=torch.uint8)
.view(torch.float8_e4m3fn)
.t()
)
scale_a = torch.randint(
128, (M, K // BLOCK_SIZE), dtype=torch.uint8, device="cuda"
)
scale_b = torch.randint(
128, (N, K // BLOCK_SIZE), dtype=torch.uint8, device="cuda"
).t()
else:
raise AssertionError(f"unsupported granularity {scaling_granularity}")

def do_matmul(A, B):
nonlocal scale_a
nonlocal scale_b
-return torch._scaled_mm(
-A, B, scale_a, scale_b, out_dtype=d3, use_fast_accum=fast_accum
-)

if scaling_granularity == "blockwise" and blockwise_dtype == "float8_e4m3":
return torch._scaled_mm(
A,
B,
scale_a,
scale_b,
bias=None,
scale_result=None,
out_dtype=d3,
use_fast_accum=fast_accum,
a_dtype=None, # inferred from A
b_dtype=None, # inferred from B
scale_dtype=DataType.E8M0,
)
elif scaling_granularity == "blockwise" and blockwise_dtype == "float4":
return torch._scaled_mm(
A,
B,
scale_a,
scale_b,
bias=None,
scale_result=None,
out_dtype=d3,
use_fast_accum=fast_accum,
a_dtype=DataType.FP4,
b_dtype=DataType.FP4,
scale_dtype=DataType.E8M0,
)

else:
return torch._scaled_mm(
A, B, scale_a, scale_b, out_dtype=d3, use_fast_accum=fast_accum
)

# test
# res = do_matmul(A, B)

fp8_time_sec, fp8_tops_sec, fp8_pct_top_peak = do_benchmarks(
tops, dtype_to_peak_tops[d1], use_gpu_kernel_time, do_matmul, A, B
)
print(
f"fp8 time_sec {fp8_time_sec:.2E}, tops/sec {fp8_tops_sec:.2E}, pct_peak {fp8_pct_top_peak:.3f}"
f"lowp time_sec {fp8_time_sec:.2E}, tops/sec {fp8_tops_sec:.2E}, pct_peak {fp8_pct_top_peak:.3f}"
)

del A, B, scale_a, scale_b
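
One way to sanity-check the "5% of peak" observation in the comment near the top of the file: compare measured time against the ideal time implied by the peak-TOPS table. A hypothetical back-of-the-envelope check, assuming tops is computed as 2 * M * K * N and using the B200 fp8 peak above:

M = K = N = 8192
tops = 2 * M * K * N                       # ~1.1e12 floating point ops per matmul
peak = 36e15                               # b200_peak_tops_float8_tc from above
ideal_time_s = tops / peak                 # ~3.1e-5 s
measured_time_s = 6.1e-4                   # illustrative measured kernel time
pct_peak = ideal_time_s / measured_time_s  # ~0.05, i.e. the "5% of peak" symptom
print(f"ideal {ideal_time_s:.2E} s, measured {measured_time_s:.2E} s, pct_peak {pct_peak:.3f}")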
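
Finally, a hypothetical way to exercise the new blockwise path directly from Python; the keyword names match the run() signature in this diff, while the module path and the availability of a Blackwell GPU with the CUDA 12.8 cuBLAS are assumptions:

from benchmarks.float8.bench_matmul import run  # illustrative import path

run(
    scaling_granularity="blockwise",
    blockwise_dtype="float8_e4m3",  # or "float4" for the MX fp4 path
    use_gpu_kernel_time=True,
)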