diff --git a/benchmarks/float8/bench_matmul.py b/benchmarks/float8/bench_matmul.py
index 3d48853754..52cfcfc481 100644
--- a/benchmarks/float8/bench_matmul.py
+++ b/benchmarks/float8/bench_matmul.py
@@ -3,6 +3,7 @@
 #
 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.
+from enum import IntEnum
 import itertools
 from typing import Optional
 
@@ -26,7 +27,16 @@
 h100_peak_flops_fp16_tc = 989e12
 h100_peak_tops_float8_tc = 1979e12
 
-dtype_to_peak_tops = {
+# HGX B200 specs: https://www.nvidia.com/en-us/data-center/hgx/
+# note: divided numbers from ^ by 2 to undo the effects of sparsity
+# TODO(this PR): I'm achieving 5% of peak TFLOPS with bf16 and float8,
+# something seems funky
+b200_peak_flops_float32 = 600e12
+b200_peak_flops_fp16_tc = 18e15
+b200_peak_tops_float8_tc = 36e15
+b200_peak_tops_float4_tc = 72e15
+
+dtype_to_peak_tops_h100 = {
     torch.float32: h100_peak_flops_float32,
     torch.float16: h100_peak_flops_fp16_tc,
     torch.bfloat16: h100_peak_flops_fp16_tc,
@@ -34,6 +44,27 @@
     torch.float8_e5m2: h100_peak_tops_float8_tc,
 }
 
+dtype_to_peak_tops_b200 = {
+    torch.float32: b200_peak_flops_float32,
+    torch.float16: b200_peak_flops_fp16_tc,
+    torch.bfloat16: b200_peak_flops_fp16_tc,
+    torch.float8_e4m3fn: b200_peak_tops_float8_tc,
+    torch.float8_e5m2: b200_peak_tops_float8_tc,
+    # TODO float4
+}
+
+# TODO(this PR): switch automatically by detected hardware type
+# TODO(this PR): fp4 is currently using fp8's peak tops below, fix it
+dtype_to_peak_tops = dtype_to_peak_tops_b200
+
+
+# not for land, matching https://www.internalfb.com/phabricator/paste/view/P1717686991
+class DataType(IntEnum):
+    DEFAULT = 0
+    E8M0 = 1
+    FP4 = 2
+    UFP8 = 3
+
 
 def benchmark_fn_in_sec(f, *args, **kwargs):
     # Manual warmup
@@ -75,6 +106,7 @@ def run(
     N: Optional[int] = None,
     use_gpu_kernel_time: bool = False,
     scaling_granularity: str = "tensorwise",
+    blockwise_dtype: Optional[str] = None,
 ):
     device = "cuda"
 
@@ -85,15 +117,17 @@ def run(
         "K",
         "N",
         "ref_time_s",
-        "fp8_time_s",
-        "fp8_speedup",
+        "lowp_time_s",
+        "lowp_speedup",
     )
     results = []
 
     dtype = torch.bfloat16
     name_to_shapes = get_name_to_shapes_iter(shape_gen_name, M, K, N)
     fast_accum_vals = [True, False]
-    scaling_granularity = ScalingGranularity(scaling_granularity)
+    # Note: blockwise not in enum because blockwise is in prototype
+    if scaling_granularity != "blockwise":
+        scaling_granularity = ScalingGranularity(scaling_granularity)
 
     for idx, (fast_accum, (name, (M, K, N))) in enumerate(
         itertools.product(fast_accum_vals, name_to_shapes)
@@ -119,28 +153,97 @@
         # raw float8 matmul (upper bound for what we can achive in eager mode)
         # TODO(future): add e5m2
         d1, d2, d3 = torch.float8_e4m3fn, torch.float8_e4m3fn, dtype
-        A = torch.zeros(M, K, device=device, dtype=d1)
-        B = torch.zeros(K, N, device=device, dtype=d2).t().contiguous().t()
+        A = torch.randn(M, K, device=device).to(d1)
+        B = torch.randn(K, N, device=device).to(d2).t().contiguous().t()
         if scaling_granularity == ScalingGranularity.TENSORWISE:
             scale_a = torch.tensor([1.0], device=device)
             scale_b = torch.tensor([1.0], device=device)
-        else:
-            assert scaling_granularity == ScalingGranularity.AXISWISE, "unsupported"
+        elif scaling_granularity == ScalingGranularity.AXISWISE:
             scale_a = torch.ones(M, 1, device=device)
             scale_b = torch.ones(1, N, device=device)
+        elif scaling_granularity == "blockwise" and blockwise_dtype == "float8_e4m3":
+            # TODO(this PR): also block size 16
+            BLOCK_SIZE = 32
+            A = torch.randint(128, (M, K), device=device, dtype=torch.uint8).view(
+                torch.float8_e4m3fn
+            )
+            B = (
+                torch.randint(128, (N, K), device=device, dtype=torch.uint8)
+                .view(torch.float8_e4m3fn)
+                .t()
+            )
+            scale_a = torch.randint(
+                128, (M, K // BLOCK_SIZE), dtype=torch.uint8, device="cuda"
+            )
+            scale_b = torch.randint(
+                128, (N, K // BLOCK_SIZE), dtype=torch.uint8, device="cuda"
+            ).t()
+        elif scaling_granularity == "blockwise" and blockwise_dtype == "float4":
+            # TODO(this PR): also block size 16
+            BLOCK_SIZE = 16
+            A = torch.randint(128, (M, K // 2), device=device, dtype=torch.uint8).view(
+                torch.float8_e4m3fn
+            )
+            B = (
+                torch.randint(128, (N, K // 2), device=device, dtype=torch.uint8)
+                .view(torch.float8_e4m3fn)
+                .t()
+            )
+            scale_a = torch.randint(
+                128, (M, K // BLOCK_SIZE), dtype=torch.uint8, device="cuda"
+            )
+            scale_b = torch.randint(
+                128, (N, K // BLOCK_SIZE), dtype=torch.uint8, device="cuda"
+            ).t()
+        else:
+            raise AssertionError(f"unsupported granularity {scaling_granularity}")
 
         def do_matmul(A, B):
             nonlocal scale_a
             nonlocal scale_b
-            return torch._scaled_mm(
-                A, B, scale_a, scale_b, out_dtype=d3, use_fast_accum=fast_accum
-            )
+
+            if scaling_granularity == "blockwise" and blockwise_dtype == "float8_e4m3":
+                return torch._scaled_mm(
+                    A,
+                    B,
+                    scale_a,
+                    scale_b,
+                    bias=None,
+                    scale_result=None,
+                    out_dtype=d3,
+                    use_fast_accum=fast_accum,
+                    a_dtype=None,  # inferred from A
+                    b_dtype=None,  # inferred from B
+                    scale_dtype=DataType.E8M0,
+                )
+            elif scaling_granularity == "blockwise" and blockwise_dtype == "float4":
+                return torch._scaled_mm(
+                    A,
+                    B,
+                    scale_a,
+                    scale_b,
+                    bias=None,
+                    scale_result=None,
+                    out_dtype=d3,
+                    use_fast_accum=fast_accum,
+                    a_dtype=DataType.FP4,
+                    b_dtype=DataType.FP4,
+                    scale_dtype=DataType.E8M0,
+                )
+
+            else:
+                return torch._scaled_mm(
+                    A, B, scale_a, scale_b, out_dtype=d3, use_fast_accum=fast_accum
+                )
+
+        # test
+        # res = do_matmul(A, B)
 
         fp8_time_sec, fp8_tops_sec, fp8_pct_top_peak = do_benchmarks(
             tops, dtype_to_peak_tops[d1], use_gpu_kernel_time, do_matmul, A, B
         )
         print(
-            f"fp8 time_sec {fp8_time_sec:.2E}, tops/sec {fp8_tops_sec:.2E}, pct_peak {fp8_pct_top_peak:.3f}"
+            f"lowp time_sec {fp8_time_sec:.2E}, tops/sec {fp8_tops_sec:.2E}, pct_peak {fp8_pct_top_peak:.3f}"
         )
 
         del A, B, scale_a, scale_b
diff --git a/test/prototype/mx_formats/test_mx_linear.py b/test/prototype/mx_formats/test_mx_linear.py
index 35afeb7959..68da09ecc6 100644
--- a/test/prototype/mx_formats/test_mx_linear.py
+++ b/test/prototype/mx_formats/test_mx_linear.py
@@ -5,6 +5,7 @@
 # LICENSE file in the root directory of this source tree.
 import copy
+from enum import IntEnum
 
 import pytest
 import torch
 
@@ -30,6 +31,14 @@
     pytest.skip("Unsupported PyTorch version", allow_module_level=True)
 
 
+# not for land, https://www.internalfb.com/phabricator/paste/view/P1717686991
+class DataType(IntEnum):
+    DEFAULT = 0
+    E8M0 = 1
+    FP4 = 2
+    UFP8 = 3
+
+
 # source: https://stackoverflow.com/a/22638709
 @pytest.fixture(autouse=True)
 def run_around_tests():
@@ -234,3 +243,87 @@ def test_filter_fn():
     swap_linear_with_mx_inference_linear(m2, torch.float8_e4m3fn, 32, filter_fn)  # noqa: E501
     assert type(m2[0]) == MXInferenceLinear
     assert type(m2[1]) == torch.nn.Linear
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(
+    not is_sm_at_least_100(), reason="blockwise torch._scaled_mm requires CUDA capability 10.0 or higher"
+)
+def test_scaled_mm_mxfp8():
+    # hello world
+    # next: basic numerics
+
+    M, K, N = 8192, 4096, 8192
+    BLOCK_SIZE = 32
+    a = torch.randint(128, (M, K), device="cuda", dtype=torch.uint8).view(
+        torch.float8_e4m3fn
+    )
+    b = (
+        torch.randint(128, (N, K), device="cuda", dtype=torch.uint8)
+        .view(torch.float8_e4m3fn)
+        .t()
+    )
+    a_scales = torch.randint(
+        128, ((M * K) // BLOCK_SIZE,), device="cuda", dtype=torch.uint8
+    ).view(M, K // BLOCK_SIZE)
+    b_scales = (
+        torch.randint(128, ((K * N) // BLOCK_SIZE,), device="cuda", dtype=torch.uint8)
+        .view(N, K // BLOCK_SIZE)
+        .t()
+    )
+    out = torch._scaled_mm(
+        a,
+        b,
+        a_scales,
+        b_scales,
+        None,
+        None,
+        torch.bfloat16,
+        False,
+        None,
+        None,
+        DataType.E8M0,
+    )
+    print(out)
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(
+    not is_sm_at_least_100(), reason="blockwise torch._scaled_mm requires CUDA capability 10.0 or higher"
+)
+def test_scaled_mm_nvfp4():
+    # hello world
+    # next: basic numerics
+
+    M, K, N = 8192, 4096, 8192
+    BLOCK_SIZE = 16
+    a = torch.randint(128, ((M * K) // 2,), device="cuda", dtype=torch.uint8).view(
+        M, K // 2
+    )
+    b = (
+        torch.randint(128, ((K * N) // 2,), device="cuda", dtype=torch.uint8)
+        .view(N, K // 2)
+        .t()
+    )
+    a_scales = torch.randint(
+        128, ((M * K) // BLOCK_SIZE,), device="cuda", dtype=torch.uint8
+    ).view(M, K // BLOCK_SIZE)
+    b_scales = (
+        torch.randint(128, ((K * N) // BLOCK_SIZE,), device="cuda", dtype=torch.uint8)
+        .view(N, K // BLOCK_SIZE)
+        .t()
+    )
+    out = torch._scaled_mm(
+        a,
+        b,
+        a_scales,
+        b_scales,
+        None,
+        None,
+        torch.bfloat16,
+        False,
+        DataType.FP4,
+        DataType.FP4,
+        DataType.UFP8,
+    )
+    print(out)
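
Context for the DataType.E8M0 scale encoding exercised above: a minimal decoding sketch, assuming the OCP MX convention that an E8M0 scale byte stores only a biased exponent (scale = 2 ** (byte - 127), with 0xFF reserved for NaN). The helper name below is illustrative only, not an API used by this PR.

import torch

# Sketch (assumption: OCP MX E8M0 semantics): map uint8 blockwise scales, e.g. the
# (M, K // BLOCK_SIZE) tensors above, to float32 power-of-two scale factors.
def decode_e8m0_scales(scales_u8: torch.Tensor) -> torch.Tensor:
    exponents = scales_u8.to(torch.int32) - 127  # E8M0 exponent bias is 127
    decoded = torch.exp2(exponents.to(torch.float32))
    # 0xFF is the reserved NaN encoding in E8M0
    return torch.where(
        scales_u8 == 0xFF, torch.full_like(decoded, float("nan")), decoded
    )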