diff --git a/benchmarks/float8/bench_matmul.py b/benchmarks/float8/bench_matmul.py
index 3d48853754..bcf30ef5fc 100644
--- a/benchmarks/float8/bench_matmul.py
+++ b/benchmarks/float8/bench_matmul.py
@@ -4,6 +4,7 @@
 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.
 import itertools
+from enum import IntEnum
 from typing import Optional
 
 import fire
@@ -26,7 +27,16 @@
 h100_peak_flops_fp16_tc = 989e12
 h100_peak_tops_float8_tc = 1979e12
 
-dtype_to_peak_tops = {
+# HGX B200 specs: https://www.nvidia.com/en-us/data-center/hgx/
+# note: divided numbers from ^ by 2 to undo the effects of sparsity
+# TODO(this PR): I'm achieving 5% of peak TFLOPS with bf16 and float8,
+# something seems funky
+b200_peak_flops_float32 = 600e12
+b200_peak_flops_fp16_tc = 18e15
+b200_peak_tops_float8_tc = 36e15
+b200_peak_tops_float4_tc = 72e15
+
+dtype_to_peak_tops_h100 = {
     torch.float32: h100_peak_flops_float32,
     torch.float16: h100_peak_flops_fp16_tc,
     torch.bfloat16: h100_peak_flops_fp16_tc,
@@ -34,6 +44,27 @@
     torch.float8_e5m2: h100_peak_tops_float8_tc,
 }
 
+dtype_to_peak_tops_b200 = {
+    torch.float32: b200_peak_flops_float32,
+    torch.float16: b200_peak_flops_fp16_tc,
+    torch.bfloat16: b200_peak_flops_fp16_tc,
+    torch.float8_e4m3fn: b200_peak_tops_float8_tc,
+    torch.float8_e5m2: b200_peak_tops_float8_tc,
+    # TODO float4
+}
+
+# TODO(this PR): switch automatically by detected hardware type
+# TODO(this PR): fp4 is currently using fp8's peak tops below, fix it
+dtype_to_peak_tops = dtype_to_peak_tops_b200
+
+
+# not for land, matching https://www.internalfb.com/phabricator/paste/view/P1717686991
+class DataType(IntEnum):
+    DEFAULT = 0
+    E8M0 = 1
+    FP4 = 2
+    UFP8 = 3
+
 
 def benchmark_fn_in_sec(f, *args, **kwargs):
     # Manual warmup
@@ -75,6 +106,7 @@ def run(
     N: Optional[int] = None,
     use_gpu_kernel_time: bool = False,
     scaling_granularity: str = "tensorwise",
+    blockwise_dtype: Optional[str] = None,
 ):
     device = "cuda"
 
@@ -85,15 +117,17 @@ def run(
         "K",
         "N",
         "ref_time_s",
-        "fp8_time_s",
-        "fp8_speedup",
+        "lowp_time_s",
+        "lowp_speedup",
     )
     results = []
 
     dtype = torch.bfloat16
     name_to_shapes = get_name_to_shapes_iter(shape_gen_name, M, K, N)
     fast_accum_vals = [True, False]
-    scaling_granularity = ScalingGranularity(scaling_granularity)
+    # Note: blockwise not in enum because blockwise is in prototype
+    if scaling_granularity != "blockwise":
+        scaling_granularity = ScalingGranularity(scaling_granularity)
 
     for idx, (fast_accum, (name, (M, K, N))) in enumerate(
         itertools.product(fast_accum_vals, name_to_shapes)
@@ -119,28 +153,97 @@
         # raw float8 matmul (upper bound for what we can achive in eager mode)
         # TODO(future): add e5m2
         d1, d2, d3 = torch.float8_e4m3fn, torch.float8_e4m3fn, dtype
-        A = torch.zeros(M, K, device=device, dtype=d1)
-        B = torch.zeros(K, N, device=device, dtype=d2).t().contiguous().t()
+        A = torch.randn(M, K, device=device).to(d1)
+        B = torch.randn(K, N, device=device).to(d2).t().contiguous().t()
         if scaling_granularity == ScalingGranularity.TENSORWISE:
             scale_a = torch.tensor([1.0], device=device)
             scale_b = torch.tensor([1.0], device=device)
-        else:
-            assert scaling_granularity == ScalingGranularity.AXISWISE, "unsupported"
+        elif scaling_granularity == ScalingGranularity.AXISWISE:
             scale_a = torch.ones(M, 1, device=device)
             scale_b = torch.ones(1, N, device=device)
+        elif scaling_granularity == "blockwise" and blockwise_dtype == "float8_e4m3":
+            # TODO(this PR): also block size 16
+            BLOCK_SIZE = 32
+            A = torch.randint(128, (M, K), device=device, dtype=torch.uint8).view(
+                torch.float8_e4m3fn
+            )
+            B = (
+                torch.randint(128, (N, K), device=device, dtype=torch.uint8)
+                .view(torch.float8_e4m3fn)
+                .t()
+            )
+            scale_a = torch.randint(
+                128, (M, K // BLOCK_SIZE), dtype=torch.uint8, device="cuda"
+            )
+            scale_b = torch.randint(
+                128, (N, K // BLOCK_SIZE), dtype=torch.uint8, device="cuda"
+            ).t()
+        elif scaling_granularity == "blockwise" and blockwise_dtype == "float4":
+            # TODO(this PR): also block size 16
+            BLOCK_SIZE = 16
+            A = torch.randint(128, (M, K // 2), device=device, dtype=torch.uint8).view(
+                torch.float8_e4m3fn
+            )
+            B = (
+                torch.randint(128, (N, K // 2), device=device, dtype=torch.uint8)
+                .view(torch.float8_e4m3fn)
+                .t()
+            )
+            scale_a = torch.randint(
+                128, (M, K // BLOCK_SIZE), dtype=torch.uint8, device="cuda"
+            )
+            scale_b = torch.randint(
+                128, (N, K // BLOCK_SIZE), dtype=torch.uint8, device="cuda"
+            ).t()
+        else:
+            raise AssertionError(f"unsupported granularity {scaling_granularity}")
 
         def do_matmul(A, B):
             nonlocal scale_a
             nonlocal scale_b
-            return torch._scaled_mm(
-                A, B, scale_a, scale_b, out_dtype=d3, use_fast_accum=fast_accum
-            )
+
+            if scaling_granularity == "blockwise" and blockwise_dtype == "float8_e4m3":
+                return torch._scaled_mm(
+                    A,
+                    B,
+                    scale_a,
+                    scale_b,
+                    bias=None,
+                    scale_result=None,
+                    out_dtype=d3,
+                    use_fast_accum=fast_accum,
+                    a_dtype=None,  # inferred from A
+                    b_dtype=None,  # inferred from B
+                    scale_dtype=DataType.E8M0,
+                )
+            elif scaling_granularity == "blockwise" and blockwise_dtype == "float4":
+                return torch._scaled_mm(
+                    A,
+                    B,
+                    scale_a,
+                    scale_b,
+                    bias=None,
+                    scale_result=None,
+                    out_dtype=d3,
+                    use_fast_accum=fast_accum,
+                    a_dtype=DataType.FP4,
+                    b_dtype=DataType.FP4,
+                    scale_dtype=DataType.E8M0,
+                )
+
+            else:
+                return torch._scaled_mm(
+                    A, B, scale_a, scale_b, out_dtype=d3, use_fast_accum=fast_accum
+                )
+
+        # test
+        # res = do_matmul(A, B)
 
         fp8_time_sec, fp8_tops_sec, fp8_pct_top_peak = do_benchmarks(
             tops, dtype_to_peak_tops[d1], use_gpu_kernel_time, do_matmul, A, B
         )
         print(
-            f"fp8 time_sec {fp8_time_sec:.2E}, tops/sec {fp8_tops_sec:.2E}, pct_peak {fp8_pct_top_peak:.3f}"
+            f"lowp time_sec {fp8_time_sec:.2E}, tops/sec {fp8_tops_sec:.2E}, pct_peak {fp8_pct_top_peak:.3f}"
         )
 
         del A, B, scale_a, scale_b
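Both files in this patch pass raw uint8 tensors to torch._scaled_mm as blockwise scales, tagged with DataType.E8M0. As a minimal sketch (not part of the patch), assuming e8m0 is the exponent-only 8-bit format from the OCP MX spec with bias 127, an encoded scale s decodes to 2 ** (s - 127), which is why the tests below treat 127 as 1.0. The decode_e8m0 helper here is hypothetical, for cross-checking scale construction only:

import torch

# Hypothetical helper, not part of the patch: decode e8m0-encoded uint8 scales
# to float32, assuming value = 2 ** (encoding - 127). Under that assumption an
# encoding of 127 decodes to 1.0, matching the "127 is 1.0 in e8m0" notes below.
def decode_e8m0(scales_u8: torch.Tensor) -> torch.Tensor:
    return torch.exp2(scales_u8.to(torch.float32) - 127.0)

# example: the all-127 scale tensors used in the tests decode to all-ones
assert torch.equal(decode_e8m0(torch.full((4,), 127, dtype=torch.uint8)), torch.ones(4))
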
diff --git a/test/prototype/mx_formats/test_mx_linear.py b/test/prototype/mx_formats/test_mx_linear.py
index 35afeb7959..bee5454311 100644
--- a/test/prototype/mx_formats/test_mx_linear.py
+++ b/test/prototype/mx_formats/test_mx_linear.py
@@ -5,6 +5,7 @@
 # LICENSE file in the root directory of this source tree.
 
 import copy
+from enum import IntEnum
 
 import pytest
 import torch
@@ -17,6 +18,7 @@
     swap_linear_with_mx_inference_linear,
     swap_linear_with_mx_linear,
 )
+from torchao.prototype.mx_formats.mx_tensor import MXTensor
 from torchao.quantization.utils import compute_error
 from torchao.utils import (
     TORCH_VERSION_AT_LEAST_2_4,
@@ -30,6 +32,14 @@
     pytest.skip("Unsupported PyTorch version", allow_module_level=True)
 
 
+# not for land, https://www.internalfb.com/phabricator/paste/view/P1717686991
+class DataType(IntEnum):
+    DEFAULT = 0
+    E8M0 = 1
+    FP4 = 2
+    UFP8 = 3
+
+
 # source: https://stackoverflow.com/a/22638709
 @pytest.fixture(autouse=True)
 def run_around_tests():
@@ -234,3 +244,383 @@ def test_filter_fn():
     swap_linear_with_mx_inference_linear(m2, torch.float8_e4m3fn, 32, filter_fn)  # noqa: E501
     assert type(m2[0]) == MXInferenceLinear
     assert type(m2[1]) == torch.nn.Linear
+
+# copy-pasted from https://github.com/drisspg/transformer_nuggets/blob/12bf63d334900d57958f839f273f5bca78a8f4a1/transformer_nuggets/mx/to_blocked.py#L54C1-L62C76
+# and modified to return 128x4 instead of 32x16
+def _to_blocked_single(scales: torch.Tensor) -> torch.Tensor:
+    """Assume that we have a 128x4 block of scales in K Major order
+
+    To see more information on the individual tile layout:
+    https://docs.nvidia.com/cuda/cublas/index.html#d-block-scaling-factors-layout
+    """
+    assert scales.shape == (128, 4)
+    scales_tiled = scales.view(4, 32, 4)  # view as 4 - (32, 4) tiles
+    return scales_tiled.transpose(0, 1).reshape(128, 4).contiguous()  # Interleave tiles
+
+def test_to_blocked():
+    scales = torch.arange(128 * 4).reshape(128, 4) / 4
+    print('orig')
+    print(scales)
+    print('blocked')
+    print(_to_blocked_single(scales))
+    # looks right!
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(
+    not is_sm_at_least_100(),
+    reason="blockwise torch._scaled_mm requires CUDA capability 10.0 or higher",
+)
+def test_scaled_mm_mxfp8_scales_one():
+    # basic numerics with all scales 1.0
+    # next: other scale values
+
+    # M, K, N = 8192, 4096, 8192
+    M, K, N = 128, 128, 128
+    BLOCK_SIZE = 32
+    a = torch.ones(M, K, device="cuda", dtype=torch.float32).to(torch.float8_e4m3fn)
+    b = torch.ones(N, K, device="cuda", dtype=torch.float32).to(torch.float8_e4m3fn).t()
+
+    # 127 is 1.0 in e8m0
+    scale_val = 127
+
+    a_scales = torch.full(
+        (M, K // BLOCK_SIZE), scale_val, device="cuda", dtype=torch.uint8
+    )
+    b_scales = torch.full(
+        (K // BLOCK_SIZE, N), scale_val, device="cuda", dtype=torch.uint8
+    ).t()
+    # b_scales[0][0] = 128
+    out = torch._scaled_mm(
+        a,
+        b,
+        a_scales,
+        b_scales,
+        None,
+        None,
+        torch.bfloat16,
+        False,
+        None,
+        None,
+        DataType.E8M0,
+    )
+
+    # [[1, 0, ...], ..., [0, ..., 1]] - correct
+    torch.set_printoptions(profile="full", linewidth=280)
+    print(out)
+    print(torch.max(out))
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(
+    not is_sm_at_least_100(),
+    reason="blockwise torch._scaled_mm requires CUDA capability 10.0 or higher",
+)
+def test_scaled_mm_mxfp8_mxtensor():
+    # baseline 1: fp32
+    # experiment 1: emulated MX from MXTensor
+    # experiment 2: real MX gemm
+
+    # results so far:
+    # * experiment 1 is very close to experiment 2
+    # * experiments 1 and 2 are far from baseline (lol!)
+
+    # M, K, N = 8192, 4096, 8192
+    M, K, N = 128, 128, 128
+    BLOCK_SIZE = 32
+    a_fp32 = torch.randn(M, K, device="cuda", dtype=torch.float32)
+    b_fp32 = torch.randn(N, K, device="cuda", dtype=torch.float32).t().contiguous()
+
+    a_mx = MXTensor.to_mx(a_fp32, torch.float8_e4m3fn, BLOCK_SIZE)
+    b_mx = MXTensor.to_mx(b_fp32, torch.float8_e4m3fn, BLOCK_SIZE).t()
+    a_s0 = a_mx._scale_e8m0.reshape(M, -1)
+    a_s1 = _to_blocked_single(a_s0)
+    b_s0 = b_mx._scale_e8m0.reshape(N, -1)
+    b_s1 = _to_blocked_single(b_s0)
+
+    # ones_scale = torch.full((M, K // BLOCK_SIZE), 127, dtype=torch.uint8, device="cuda")
+
+    out_ref = a_fp32 @ b_fp32
+    print('baseline', out_ref)
+
+    out_mx_emulated = a_mx @ b_mx
+    print('mx_emulated', out_mx_emulated)
+
+    out_mx_real = torch._scaled_mm(
+        a_mx._data,
+        b_mx._data,
+        # a_scales is really b_scales, and vice versa. Probably switched in cuBLAS kernel?
+        _to_blocked_single(b_mx._scale_e8m0.reshape(N, -1)),
+        _to_blocked_single(a_mx._scale_e8m0.reshape(M, -1)),
+        None,
+        None,
+        torch.float32,
+        False,
+        None,
+        None,
+        DataType.E8M0,
+    )
+    print('mx_real', out_mx_real)
+
+    sqnr_baseline_to_emulated_mx = compute_error(out_ref, out_mx_emulated)
+    sqnr_baseline_to_real_mx = compute_error(out_ref, out_mx_real)
+    sqnr_emulated_mx_to_real_mx = compute_error(out_mx_emulated, out_mx_real)
+    print('sqnr baseline -> emulated_mx', sqnr_baseline_to_emulated_mx)
+    print('sqnr baseline -> real_mx', sqnr_baseline_to_real_mx)
+    print('sqnr emulated_mx -> real_mx', sqnr_emulated_mx_to_real_mx)
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(
+    not is_sm_at_least_100(),
+    reason="blockwise torch._scaled_mm requires CUDA capability 10.0 or higher",
+)
+def test_scaled_mm_mx_reconstruct_scale_a_layout():
+    # brute force the expected layout format
+    # basic numerics with all scales 1.0
+    # next: other scale values
+
+    # M, K, N = 8192, 4096, 8192
+    M, K, N = 128, 128, 128
+    BLOCK_SIZE = 32
+    a = torch.ones(M, K, device="cuda", dtype=torch.float32).to(torch.float8_e4m3fn)
+
+    # 127 is 1.0 in e8m0
+    scale_val = 127
+
+    print()
+
+    # Probe torch._scaled_mm to deduce the actual layout used for the scale
+    # arguments. Specifically, here is what the code below would do if we had
+    # A and B as 4x4 matrices with MX block size 2. All matrices are shown in float32
+    # format, not their actual storage format, to demonstrate the algorithm.
+    #
+    # A matrix - set to all-ones
+    #
+    # A = 1111
+    #     1111
+    #     1111
+    #     1111
+    #
+    # B matrix variants - all-zeros, except a single one for each mx block in the first column
+    #
+    # B_0 = 1000   B_1 = 0000
+    #       0000         0000
+    #       0000         1000
+    #       0000         0000
+    #
+    # A scale - starts as a matrix of all-ones
+    #
+    # A_s = 11
+    #       11
+    #       11
+    #       11
+    #
+    # for each row in rows of A_s:
+    #   for each col in cols of A_s:
+    #     initialize A_s to all-ones
+    #     set A_s[row][col] = 2.0
+    #     for each B in [Bs]:
+    #       C = torch._scaled_mm(A, B, A_s, B_s, ...)
+    #       if max(C) > 1.0:
+    #         the scale incremented in A_s corresponds to the current block
+
+    for scale_row in range(M):
+        for scale_col in range(K // BLOCK_SIZE):
+
+            a_scales = torch.full(
+                (M, K // BLOCK_SIZE), scale_val, device="cuda", dtype=torch.uint8
+            )
+            b_scales = torch.full(
+                # (K // BLOCK_SIZE, N), scale_val, device="cuda", dtype=torch.uint8
+                (N, K // BLOCK_SIZE), scale_val, device="cuda", dtype=torch.uint8
+            )
+
+            # TODO: it looks like blockwise scales are switched in cuBLAS?
+            # incrementing scale of b looks like it's actually affecting scaling of a
+            b_scales[scale_row][scale_col] = scale_val + 1
+
+            # We test every block along K to deduce which of the blocks is
+            # responsible for the scale value. Note that this isn't the most
+            # efficient way to test, but I'm optimizing for dev time here.
+            for block_idx in range(K // BLOCK_SIZE):
+
+                b = torch.zeros(N, K, device="cuda", dtype=torch.float32)
+                # set a single one inside the block
+                b[0][block_idx * BLOCK_SIZE] = 1
+                b = b.to(torch.float8_e4m3fn).t()
+
+                out = torch._scaled_mm(
+                    a,
+                    b,
+                    a_scales,
+                    b_scales,
+                    None,
+                    None,
+                    torch.bfloat16,
+                    False,
+                    None,
+                    None,
+                    DataType.E8M0,
+                )
+
+                # print(scale_row, scale_col, block_idx)
+                # torch.set_printoptions(profile="full", linewidth=320)
+                # print(out)
+                # print(torch.max(out, keepdim=True))
+
+                max_val = torch.max(out).item()
+                if max_val > 1:
+                    max_flat_index = torch.argmax(out).item()
+                    max_row = max_flat_index // M
+                    max_col = max_flat_index % M
+                    assert max_col == 0
+                    assert max_val == 2.0
+                    print('scale_coords', scale_row, scale_col, 'block_idx', block_idx, 'max_coords', max_row, max_col, 'max_val', max_val)
+
+            # break
+        # break
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(
+    not is_sm_at_least_100(),
+    reason="blockwise torch._scaled_mm requires CUDA capability 10.0 or higher",
+)
+def test_scaled_mm_mx_reconstruct_scale_b_layout():
+    # brute force the expected layout format
+    # basic numerics with all scales 1.0
+    # next: other scale values
+
+    # M, K, N = 8192, 4096, 8192
+    M, K, N = 128, 128, 128
+    BLOCK_SIZE = 32
+    b = torch.ones(M, K, device="cuda", dtype=torch.float32).to(torch.float8_e4m3fn).t()
+
+    # 127 is 1.0 in e8m0
+    scale_val = 127
+
+    print()
+
+    # Probe torch._scaled_mm to deduce the actual layout used for the scale
+    # arguments. Specifically, here is what the code below would do if we had
+    # A and B as 4x4 matrices with MX block size 2. All matrices are shown in float32
+    # format, not their actual storage format, to demonstrate the algorithm.
+    #
+    # A matrix variants - all-zeros, except a single one for each mx block in the first row
+    #
+    # A_0 = 1000   A_1 = 0010
+    #       0000         0000
+    #       0000         0000
+    #       0000         0000
+    #
+    # B matrix - set to all-ones
+    #
+    # B = 1111
+    #     1111
+    #     1111
+    #     1111
+    #
+    # B scale - starts as a matrix of all-ones
+    #
+    # B_s = 11
+    #       11
+    #       11
+    #       11
+    #
+    # for each row in rows of B_s:
+    #   for each col in cols of B_s:
+    #     initialize B_s to all-ones
+    #     set B_s[row][col] = 2.0
+    #     for each A in [As]:
+    #       C = torch._scaled_mm(A, B, A_s, B_s, ...)
+    #       if max(C) > 1.0:
+    #         the scale incremented in B_s corresponds to the current block
+
+    for scale_row in range(M):
+        for scale_col in range(K // BLOCK_SIZE):
+
+            a_scales = torch.full(
+                (M, K // BLOCK_SIZE), scale_val, device="cuda", dtype=torch.uint8
+            )
+            b_scales = torch.full(
+                # (K // BLOCK_SIZE, N), scale_val, device="cuda", dtype=torch.uint8
+                (N, K // BLOCK_SIZE), scale_val, device="cuda", dtype=torch.uint8
+            )
+
+            # TODO: it looks like blockwise scales are switched in cuBLAS?
+            # incrementing scale of a looks like it's actually affecting scaling of b
+            a_scales[scale_row][scale_col] = scale_val + 1
+
+            # We test every block along K to deduce which of the blocks is
+            # responsible for the scale value. Note that this isn't the most
+            # efficient way to test, but I'm optimizing for dev time here.
+            for block_idx in range(K // BLOCK_SIZE):
+
+                a = torch.zeros(M, K, device="cuda", dtype=torch.float32)
+                # set a single one inside the block
+                a[0][block_idx * BLOCK_SIZE] = 1
+                a = a.to(torch.float8_e4m3fn)
+
+                out = torch._scaled_mm(
+                    a,
+                    b,
+                    a_scales,
+                    b_scales,
+                    None,
+                    None,
+                    torch.bfloat16,
+                    False,
+                    None,
+                    None,
+                    DataType.E8M0,
+                )
+
+                # print(scale_row, scale_col, block_idx)
+                # torch.set_printoptions(profile="full", linewidth=320)
+                # print(out)
+                # print(torch.max(out, keepdim=True))
+
+                max_val = torch.max(out).item()
+                if max_val > 1:
+                    max_flat_index = torch.argmax(out).item()
+                    max_row = max_flat_index // M
+                    max_col = max_flat_index % M
+                    assert max_row == 0
+                    assert max_val == 2.0
+                    print('scale_coords', scale_row, scale_col, 'block_idx', block_idx, 'max_coords', max_row, max_col, 'max_val', max_val)
+
+            # break
+        # break
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(
+    not is_sm_at_least_100(),
+    reason="blockwise torch._scaled_mm requires CUDA capability 10.0 or higher",
+)
+def test_scaled_mm_nvfp4():
+    # hello world
+    # next: basic numerics
+
+    M, K, N = 8192, 4096, 8192
+    BLOCK_SIZE = 16
+    a = torch.randint(128, (M, K // 2), device="cuda", dtype=torch.uint8)
+    b = torch.randint(128, (N, K // 2), device="cuda", dtype=torch.uint8).t()
+    a_scales = torch.randint(
+        128, (M, K // BLOCK_SIZE), device="cuda", dtype=torch.uint8
+    )
+    b_scales = torch.randint(
+        128, (N, K // BLOCK_SIZE), device="cuda", dtype=torch.uint8
+    ).t()
+    out = torch._scaled_mm(
+        a,
+        b,
+        a_scales,
+        b_scales,
+        None,
+        None,
+        torch.bfloat16,
+        False,
+        DataType.FP4,
+        DataType.FP4,
+        DataType.UFP8,
+    )
+    print(out)
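
For reference, a rough emulation of the blockwise (MX) matmul exercised above, under the same assumptions the tests make: fp8 e4m3 data, e8m0 scales decoding as 2 ** (s - 127), and one scale per BLOCK_SIZE contiguous elements along K. The emulated_mx_mm helper below is hypothetical and only meant as a sketch; it ignores the _to_blocked_single rearrangement, which affects only how scales are laid out for the kernel, not the math:

import torch

def emulated_mx_mm(a_fp8, a_scales_u8, b_fp8, b_scales_u8, block_size=32):
    # a_fp8: (M, K) float8_e4m3fn, a_scales_u8: (M, K // block_size) uint8 (e8m0)
    # b_fp8: (N, K) float8_e4m3fn, b_scales_u8: (N, K // block_size) uint8 (e8m0)
    # dequantize each operand blockwise, then matmul in float32 -> (M, N)
    a_scale = torch.exp2(a_scales_u8.float() - 127).repeat_interleave(block_size, dim=1)
    b_scale = torch.exp2(b_scales_u8.float() - 127).repeat_interleave(block_size, dim=1)
    return (a_fp8.float() * a_scale) @ (b_fp8.float() * b_scale).t()

Comparing this against the torch._scaled_mm output gives a second reference point besides MXTensor when deciding whether a mismatch comes from the gemm numerics or from the scale layout.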