[not for land] hook up MX to CUDA 12.8 cuBLAS MX gemm
Summary:

Requires https://github.com/pytorch/pytorch/pull/145562/files

None of this is for land - just testing for now as we work on a long-term support plan.

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: 9d7e30784817d9d89dcb98e6a33067928fa5db11
ghstack-comment-id: 2616580142
Pull Request resolved: #1625
vkuzo committed Jan 27, 2025
1 parent f22207b commit eb12c3e
Showing 2 changed files with 208 additions and 12 deletions.
benchmarks/float8/bench_matmul.py: 127 changes (115 additions, 12 deletions)
@@ -3,6 +3,7 @@
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.
from enum import IntEnum
import itertools
from typing import Optional

@@ -26,14 +27,44 @@
h100_peak_flops_fp16_tc = 989e12
h100_peak_tops_float8_tc = 1979e12

dtype_to_peak_tops = {
# HGX B200 specs: https://www.nvidia.com/en-us/data-center/hgx/
# note: divided numbers from ^ by 2 to undo the effects of sparsity
# TODO(this PR): I'm achieving 5% of peak TFLOPS with bf16 and float8,
# something seems funky
b200_peak_flops_float32 = 600e12
b200_peak_flops_fp16_tc = 18e15
b200_peak_tops_float8_tc = 36e15
b200_peak_tops_float4_tc = 72e15

dtype_to_peak_tops_h100 = {
torch.float32: h100_peak_flops_float32,
torch.float16: h100_peak_flops_fp16_tc,
torch.bfloat16: h100_peak_flops_fp16_tc,
torch.float8_e4m3fn: h100_peak_tops_float8_tc,
torch.float8_e5m2: h100_peak_tops_float8_tc,
}

dtype_to_peak_tops_b200 = {
torch.float32: b200_peak_flops_float32,
torch.float16: b200_peak_flops_fp16_tc,
torch.bfloat16: b200_peak_flops_fp16_tc,
torch.float8_e4m3fn: b200_peak_tops_float8_tc,
torch.float8_e5m2: b200_peak_tops_float8_tc,
# TODO float4
}

# TODO(this PR): switch automatically by detected hardware type
# TODO(this PR): fp4 is currently using fp8's peak tops below, fix it
dtype_to_peak_tops = dtype_to_peak_tops_b200


# not for land, matching https://www.internalfb.com/phabricator/paste/view/P1717686991
class DataType(IntEnum):
DEFAULT = 0
E8M0 = 1
FP4 = 2
UFP8 = 3


def benchmark_fn_in_sec(f, *args, **kwargs):
# Manual warmup
@@ -75,6 +106,7 @@ def run(
N: Optional[int] = None,
use_gpu_kernel_time: bool = False,
scaling_granularity: str = "tensorwise",
blockwise_dtype: Optional[str] = None,
):
device = "cuda"

@@ -85,15 +117,17 @@ def run(
"K",
"N",
"ref_time_s",
"fp8_time_s",
"fp8_speedup",
"lowp_time_s",
"lowp_speedup",
)
results = []

dtype = torch.bfloat16
name_to_shapes = get_name_to_shapes_iter(shape_gen_name, M, K, N)
fast_accum_vals = [True, False]
scaling_granularity = ScalingGranularity(scaling_granularity)
# Note: blockwise is not in the ScalingGranularity enum because it is still in prototype
if scaling_granularity != "blockwise":
scaling_granularity = ScalingGranularity(scaling_granularity)

for idx, (fast_accum, (name, (M, K, N))) in enumerate(
itertools.product(fast_accum_vals, name_to_shapes)
@@ -119,28 +153,97 @@ def run(
# raw float8 matmul (upper bound for what we can achieve in eager mode)
# TODO(future): add e5m2
d1, d2, d3 = torch.float8_e4m3fn, torch.float8_e4m3fn, dtype
A = torch.zeros(M, K, device=device, dtype=d1)
B = torch.zeros(K, N, device=device, dtype=d2).t().contiguous().t()
A = torch.randn(M, K, device=device).to(d1)
B = torch.randn(K, N, device=device).to(d2).t().contiguous().t()
if scaling_granularity == ScalingGranularity.TENSORWISE:
scale_a = torch.tensor([1.0], device=device)
scale_b = torch.tensor([1.0], device=device)
else:
assert scaling_granularity == ScalingGranularity.AXISWISE, "unsupported"
elif scaling_granularity == ScalingGranularity.AXISWISE:
scale_a = torch.ones(M, 1, device=device)
scale_b = torch.ones(1, N, device=device)
elif scaling_granularity == "blockwise" and blockwise_dtype == "float8_e4m3":
# TODO(this PR): also block size 16
BLOCK_SIZE = 32
A = torch.randint(128, (M, K), device=device, dtype=torch.uint8).view(
torch.float8_e4m3fn
)
B = (
torch.randint(128, (N, K), device=device, dtype=torch.uint8)
.view(torch.float8_e4m3fn)
.t()
)
scale_a = torch.randint(
128, (M, K // BLOCK_SIZE), dtype=torch.uint8, device="cuda"
)
scale_b = torch.randint(
128, (N, K // BLOCK_SIZE), dtype=torch.uint8, device="cuda"
).t()
elif scaling_granularity == "blockwise" and blockwise_dtype == "float4":
# TODO(this PR): also block size 16
BLOCK_SIZE = 16
A = torch.randint(128, (M, K // 2), device=device, dtype=torch.uint8).view(
torch.float8_e4m3fn
)
B = (
torch.randint(128, (N, K // 2), device=device, dtype=torch.uint8)
.view(torch.float8_e4m3fn)
.t()
)
scale_a = torch.randint(
128, (M, K // BLOCK_SIZE), dtype=torch.uint8, device="cuda"
)
scale_b = torch.randint(
128, (N, K // BLOCK_SIZE), dtype=torch.uint8, device="cuda"
).t()
else:
raise AssertionError(f"unsupported granularity {scaling_granularity}")

def do_matmul(A, B):
nonlocal scale_a
nonlocal scale_b
return torch._scaled_mm(
A, B, scale_a, scale_b, out_dtype=d3, use_fast_accum=fast_accum
)

if scaling_granularity == "blockwise" and blockwise_dtype == "float8_e4m3":
return torch._scaled_mm(
A,
B,
scale_a,
scale_b,
bias=None,
scale_result=None,
out_dtype=d3,
use_fast_accum=fast_accum,
a_dtype=None, # inferred from A
b_dtype=None, # inferred from B
scale_dtype=DataType.E8M0,
)
elif scaling_granularity == "blockwise" and blockwise_dtype == "float4":
return torch._scaled_mm(
A,
B,
scale_a,
scale_b,
bias=None,
scale_result=None,
out_dtype=d3,
use_fast_accum=fast_accum,
a_dtype=DataType.FP4,
b_dtype=DataType.FP4,
scale_dtype=DataType.E8M0,
)

else:
return torch._scaled_mm(
A, B, scale_a, scale_b, out_dtype=d3, use_fast_accum=fast_accum
)

# test
# res = do_matmul(A, B)

fp8_time_sec, fp8_tops_sec, fp8_pct_top_peak = do_benchmarks(
tops, dtype_to_peak_tops[d1], use_gpu_kernel_time, do_matmul, A, B
)
print(
f"fp8 time_sec {fp8_time_sec:.2E}, tops/sec {fp8_tops_sec:.2E}, pct_peak {fp8_pct_top_peak:.3f}"
f"lowp time_sec {fp8_time_sec:.2E}, tops/sec {fp8_tops_sec:.2E}, pct_peak {fp8_pct_top_peak:.3f}"
)

del A, B, scale_a, scale_b
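
As a reference for the blockwise branches above: both branches allocate one uint8 scale per BLOCK_SIZE elements along K, and the fp4 branch packs two elements per uint8 byte. Below is a minimal standalone sketch of that shape bookkeeping, not part of the commit's diff; the helper name mx_operand_shapes is illustrative only.

def mx_operand_shapes(M, K, N, block_size, is_fp4):
    # One scale byte per (1, block_size) block along K; fp4 packs two
    # elements per uint8 byte, so the packed K dimension is halved.
    assert K % block_size == 0, "K must be divisible by the MX block size"
    elems_per_byte = 2 if is_fp4 else 1
    a_data = (M, K // elems_per_byte)
    b_data = (N, K // elems_per_byte)  # stored as (N, K_packed), then .t() before the matmul
    a_scale = (M, K // block_size)
    b_scale = (N, K // block_size)
    return a_data, b_data, a_scale, b_scale

# mxfp8 settings from the benchmark (BLOCK_SIZE = 32):
print(mx_operand_shapes(8192, 4096, 8192, 32, is_fp4=False))
# ((8192, 4096), (8192, 4096), (8192, 128), (8192, 128))

# nvfp4 settings (BLOCK_SIZE = 16, two fp4 values per byte):
print(mx_operand_shapes(8192, 4096, 8192, 16, is_fp4=True))
# ((8192, 2048), (8192, 2048), (8192, 256), (8192, 256))
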
test/prototype/mx_formats/test_mx_linear.py: 93 changes (93 additions, 0 deletions)
@@ -5,6 +5,7 @@
# LICENSE file in the root directory of this source tree.

import copy
from enum import IntEnum

import pytest
import torch
@@ -30,6 +31,14 @@
pytest.skip("Unsupported PyTorch version", allow_module_level=True)


# not for land, https://www.internalfb.com/phabricator/paste/view/P1717686991
class DataType(IntEnum):
DEFAULT = 0
E8M0 = 1
FP4 = 2
UFP8 = 3


# source: https://stackoverflow.com/a/22638709
@pytest.fixture(autouse=True)
def run_around_tests():
@@ -234,3 +243,87 @@ def test_filter_fn():
swap_linear_with_mx_inference_linear(m2, torch.float8_e4m3fn, 32, filter_fn) # noqa: E501
assert type(m2[0]) == MXInferenceLinear
assert type(m2[1]) == torch.nn.Linear


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(
not is_sm_at_least_100(), reason="blockwise torch._scaled_mm requires CUDA capability 10.0 or higher"
)
def test_scaled_mm_mxfp8():
# hello world
# next: basic numerics

M, K, N = 8192, 4096, 8192
BLOCK_SIZE = 32
a = torch.randint(128, (M, K), device="cuda", dtype=torch.uint8).view(
torch.float8_e4m3fn
)
b = (
torch.randint(128, (N, K), device="cuda", dtype=torch.uint8)
.view(torch.float8_e4m3fn)
.t()
)
a_scales = torch.randint(
128, ((M * K) // BLOCK_SIZE,), device="cuda", dtype=torch.uint8
).view(M, K // BLOCK_SIZE)
b_scales = (
torch.randint(128, ((K * N) // BLOCK_SIZE,), device="cuda", dtype=torch.uint8)
.view(N, K // BLOCK_SIZE)
.t()
)
out = torch._scaled_mm(
a,
b,
a_scales,
b_scales,
None,
None,
torch.bfloat16,
False,
None,
None,
DataType.E8M0,
)
print(out)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(
not is_sm_at_least_100(), reason="blockwise torch._scaled_mm requires CUDA capability 10.0 or higher"
)
def test_scaled_mm_nvfp4():
# hello world
# next: basic numerics

M, K, N = 8192, 4096, 8192
BLOCK_SIZE = 16
a = torch.randint(128, ((M * K) // 2,), device="cuda", dtype=torch.uint8).view(
M, K // 2
)
b = (
torch.randint(128, ((K * N) // 2,), device="cuda", dtype=torch.uint8)
.view(N, K // 2)
.t()
)
a_scales = torch.randint(
128, ((M * K) // BLOCK_SIZE,), device="cuda", dtype=torch.uint8
).view(M, K // BLOCK_SIZE)
b_scales = (
torch.randint(128, ((K * N) // BLOCK_SIZE,), device="cuda", dtype=torch.uint8)
.view(N, K // BLOCK_SIZE)
.t()
)
out = torch._scaled_mm(
a,
b,
a_scales,
b_scales,
None,
None,
torch.bfloat16,
False,
DataType.FP4,
DataType.FP4,
DataType.UFP8,
)
print(out)
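
For interpreting the uint8 scale tensors in the mxfp8 test above: assuming the OCP MX convention for E8M0 (an exponent-only byte with bias 127, 0xFF reserved for NaN), a scale byte b stands for 2**(b - 127). A small decoding sketch under that assumption follows; the helper e8m0_to_float is illustrative, not an existing torch API.

import torch

def e8m0_to_float(scales_u8: torch.Tensor) -> torch.Tensor:
    # Decode bias-127, exponent-only bytes to float scales; 0xFF maps to NaN.
    out = torch.pow(2.0, scales_u8.to(torch.float32) - 127.0)
    out[scales_u8 == 0xFF] = float("nan")
    return out

# Byte 126 -> 0.5, 127 -> 1.0, 128 -> 2.0
print(e8m0_to_float(torch.tensor([126, 127, 128], dtype=torch.uint8)))
# tensor([0.5000, 1.0000, 2.0000])
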
