From c21d24ce8f2d3a42e02a84a1ffcc0524fc548e8a Mon Sep 17 00:00:00 2001
From: "Peter Y. Yeh"
Date: Wed, 16 Apr 2025 15:54:23 -0700
Subject: [PATCH 1/4] Enhance MX formats to support HIPBLASLT kernel choice and
 update validation logic. Added MXFP8_HIPBLASLT recipe and adjusted mx_mm
 function to accommodate new kernel options.

---
 torchao/prototype/mx_formats/config.py | 15 +++++++++++++++
 torchao/prototype/mx_formats/mx_ops.py | 23 ++++++++++++++++++-----
 2 files changed, 33 insertions(+), 5 deletions(-)

diff --git a/torchao/prototype/mx_formats/config.py b/torchao/prototype/mx_formats/config.py
index e1599cfad5..6d38362899 100644
--- a/torchao/prototype/mx_formats/config.py
+++ b/torchao/prototype/mx_formats/config.py
@@ -33,12 +33,16 @@ class MXGemmKernelChoice(Enum):
     # note: torch.compile does not work yet, see https://github.com/pytorch/pytorch/issues/147873
     CUBLAS = "cublas"
 
+    # available only on ROCm with HIPBLASLT support
+    HIPBLASLT = "hipblaslt"
+
 
 # Pre-made recipes for common configurations
 class MXLinearRecipeName(Enum):
     MXFP8_EMULATED = "mxfp8_emulated"
     MXFP8_CUBLAS = "mxfp8_cublas"
     MXFP8_CUTLASS = "mxfp8_cutlass"
+    MXFP8_HIPBLASLT = "mxfp8_hipblaslt"
     MXFP4_EMULATED = "mxfp4_emulated"
     MXFP4_CUTLASS = "mxfp4_cutlass"
 
@@ -66,6 +70,15 @@ def _validate_gemm_kernel_choice(gemm_kernel_choice, block_size, elem_dtype):
         assert (
             elem_dtype in valid_dtypes
         ), f"elem_dtype must be one of {valid_dtypes} to use the CUTLASS MX gemm kernels, got {elem_dtype}"
+    elif gemm_kernel_choice == MXGemmKernelChoice.HIPBLASLT:
+        assert (
+            block_size == 32
+        ), f"block_size must be 32 to use the HIPBLASLT MX gemm kernels, got {block_size}"
+        valid_dtypes = [torch.float8_e4m3fn]
+        assert (
+            elem_dtype in valid_dtypes
+        ), f"elem_dtype must be one of {valid_dtypes} to use the HIPBLASLT MX gemm kernels, got {elem_dtype}"
+        assert torch.version.hip is not None, "HIPBLASLT requires ROCm"
 
 
 @dataclass
@@ -128,6 +141,8 @@ def from_recipe_name(
             return MXLinearConfig(gemm_kernel_choice=MXGemmKernelChoice.CUBLAS)
         elif recipe_name is MXLinearRecipeName.MXFP8_CUTLASS:
             return MXLinearConfig(gemm_kernel_choice=MXGemmKernelChoice.CUTLASS)
+        elif recipe_name is MXLinearRecipeName.MXFP8_HIPBLASLT:
+            return MXLinearConfig(gemm_kernel_choice=MXGemmKernelChoice.HIPBLASLT)
         elif recipe_name is MXLinearRecipeName.MXFP4_EMULATED:
             return MXLinearConfig(elem_dtype=DTYPE_FP4)
         elif recipe_name is MXLinearRecipeName.MXFP4_CUTLASS:
diff --git a/torchao/prototype/mx_formats/mx_ops.py b/torchao/prototype/mx_formats/mx_ops.py
index c5d60a33de..f08fa9ee28 100644
--- a/torchao/prototype/mx_formats/mx_ops.py
+++ b/torchao/prototype/mx_formats/mx_ops.py
@@ -75,8 +75,14 @@ def mx_mm(aten_op, args, kwargs=None):
     b = args[1]
     assert isinstance(a, MXTensor) and isinstance(b, MXTensor)
     assert a._gemm_kernel_choice == b._gemm_kernel_choice, "unsupported"
-    if a._gemm_kernel_choice in (MXGemmKernelChoice.CUBLAS, MXGemmKernelChoice.CUTLASS):
-        # real MX gemm backed by torchao's CUTLASS kernels
+    kernel_choice = a._gemm_kernel_choice
+    valid_kernels = (
+        MXGemmKernelChoice.CUBLAS,
+        MXGemmKernelChoice.CUTLASS,
+        MXGemmKernelChoice.HIPBLASLT,
+    )
+    if kernel_choice in valid_kernels:
+        # real MX gemm backed by torchao's CUTLASS/CUBLAS/HIPBLASLT kernels
         M, K, N = a.shape[0], a.shape[1], b.shape[1]
         assert a._data.is_contiguous()
         assert b._data.t().is_contiguous()
@@ -88,7 +94,12 @@ def mx_mm(aten_op, args, kwargs=None):
         b_scale_block = to_blocked(b_scale)
         if a._elem_dtype == torch.float8_e4m3fn:
             assert b._elem_dtype == torch.float8_e4m3fn
-            if a._gemm_kernel_choice is MXGemmKernelChoice.CUBLAS:
+            scaled_mm_kernels = (
+                MXGemmKernelChoice.CUBLAS,
+                MXGemmKernelChoice.HIPBLASLT,
+            )
+            if kernel_choice in scaled_mm_kernels:
+                # Use native scaled_mm for both CUBLAS and HIPBLASLT
                 res = torch._scaled_mm(
                     a._data,
                     b._data,
@@ -103,7 +114,8 @@ def mx_mm(aten_op, args, kwargs=None):
         else:
             assert a._elem_dtype == DTYPE_FP4
             assert b._elem_dtype == DTYPE_FP4
-            assert a._gemm_kernel_choice is MXGemmKernelChoice.CUTLASS, "unsupported"
+            msg = "FP4 is only supported with CUTLASS kernel at this moment"
+            assert kernel_choice is MXGemmKernelChoice.CUTLASS, msg
             res = torchao.ops.mx_fp4_bf16(
                 a._data, b._data, a_scale_block, b_scale_block
             )
@@ -162,7 +174,8 @@ def mx_view_op(aten_op, args, kwargs=None):
     if args[0]._elem_dtype == DTYPE_FP4:
         # special case fp4 as we pack two elements per byte
         new_size = tensor_size_hp_to_fp4x2(new_size, data.is_contiguous())
-    elif args[0]._elem_dtype in [DTYPE_FP6_E3M2, DTYPE_FP6_E2M3] and args[0]._pack_fp6:
+    elif (args[0]._elem_dtype in [DTYPE_FP6_E3M2, DTYPE_FP6_E2M3] and
+          args[0]._pack_fp6):
         # special case fp6 as we pack 4 elements in 3 bytes
         new_size = tensor_size_hpx3_to_fp6x4(new_size, data.is_contiguous())
     new_data = aten_op(data, new_size, *args[2:], **kwargs)
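For reference, a minimal sketch of how the additions in this patch are meant to be exercised. It is a sketch only: the import path mirrors the file touched above, a ROCm build of PyTorch is assumed (the new validation asserts `torch.version.hip is not None`), and it assumes `MXLinearConfig` runs `_validate_gemm_kernel_choice` when it is constructed, which this patch alone does not show.

```python
# Sketch only: exercising the HIPBLASLT recipe and validation added above.
# Assumes a ROCm build of PyTorch; module path follows the file in the diff.
import torch

from torchao.prototype.mx_formats.config import (
    MXGemmKernelChoice,
    MXLinearConfig,
    MXLinearRecipeName,
    _validate_gemm_kernel_choice,
)

# the new recipe resolves to the HIPBLASLT kernel choice
config = MXLinearConfig.from_recipe_name(MXLinearRecipeName.MXFP8_HIPBLASLT)
assert config.gemm_kernel_choice is MXGemmKernelChoice.HIPBLASLT

# the validation helper only accepts block_size == 32 and fp8 e4m3 elements,
# and additionally requires ROCm for the HIPBLASLT path
_validate_gemm_kernel_choice(
    MXGemmKernelChoice.HIPBLASLT, block_size=32, elem_dtype=torch.float8_e4m3fn
)

try:
    _validate_gemm_kernel_choice(
        MXGemmKernelChoice.HIPBLASLT, block_size=16, elem_dtype=torch.float8_e4m3fn
    )
except AssertionError as e:
    print(e)  # block_size must be 32 to use the HIPBLASLT MX gemm kernels, got 16
```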
From 36dd5b7a47884c4a862064dca864ca214db89a7d Mon Sep 17 00:00:00 2001
From: "Peter Y. Yeh"
Date: Wed, 16 Apr 2025 16:11:41 -0700
Subject: [PATCH 2/4] Update README.md to include support for AMD MI355x
 hardware and HIPBLASLT kernel choice for mxfp8 gemm. Enhance documentation on
 end-to-end performance optimization efforts for AMD GPUs.

---
 torchao/prototype/mx_formats/README.md | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/torchao/prototype/mx_formats/README.md b/torchao/prototype/mx_formats/README.md
index 955a02704f..b9224c2594 100644
--- a/torchao/prototype/mx_formats/README.md
+++ b/torchao/prototype/mx_formats/README.md
@@ -1,7 +1,7 @@
 # MX training and inference with native PyTorch
 
 This is a workflow for e2e training and inference with MX dtypes from the [MX OCP spec](https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf)
-in native PyTorch. We are currently in prototype and are actively working on optimizing these workflows on the NVIDIA B200 hardware.
+in native PyTorch. We are currently in prototype and are actively working on optimizing these workflows on the NVIDIA B200 and AMD MI355x hardware.
 
 ## Overall status
 
@@ -29,6 +29,9 @@ from torchao.prototype.mx_formats import MXLinearConfig, MXGemmKernelChoice
 gemm_kernel_choice = MXGemmKernelChoice.CUBLAS
 # gemm_kernel_choice = MXGemmKernelChoice.CUTLASS
 
+# on AMD MI355x GPUs with ROCm 6.5+ and gfx950, you can use HIPBLASLT mxfp8 kernels
+# gemm_kernel_choice = MXGemmKernelChoice.HIPBLASLT
+
 # on older NVIDIA gpus, you can run training with emulated MX gemm
 # gemm_kernel_choice = MXGemmKernelChoice.EMULATED
 
@@ -97,6 +100,8 @@ on supported hardware, you can run the following command:
 // example output: https://gist.github.com/vkuzo/a1ddb782e6e1c2aef0c726b3df99efbc
 ```
 
+On AMD MI355x GPUs with ROCm 6.5+ and gfx950, we use HIPBLASLT for mxfp8 gemm. We are actively working on optimizing the end-to-end performance for AMD hardware.
+
 ## to_mx cast across dim0 and dim1
 
 On NVIDIA B200 machines, our to_mx kernels for mxfp8 achieve **up to 5.5 TB/s** for the dim0 cast (with torch.compile),
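A hedged end-to-end sketch matching the README snippet patched above. The `quantize_` entry point from `torchao.quantization` and the `MXLinearConfig` field names are assumptions taken from the rest of the README rather than from this hunk, and an MI355x/ROCm 6.5+ environment is assumed for the HIPBLASLT path.

```python
# Sketch only: training-time swap using the HIPBLASLT kernel choice.
# Assumes torchao.quantization.quantize_ accepts an MXLinearConfig, as in the
# rest of this README, and a ROCm (MI355x, gfx950) environment.
import torch

from torchao.quantization import quantize_
from torchao.prototype.mx_formats import MXGemmKernelChoice, MXLinearConfig

# pick HIPBLASLT on ROCm builds, fall back to emulation elsewhere
gemm_kernel_choice = (
    MXGemmKernelChoice.HIPBLASLT
    if torch.version.hip is not None
    else MXGemmKernelChoice.EMULATED
)

m = torch.nn.Sequential(torch.nn.Linear(32, 32)).cuda().to(torch.bfloat16)
config = MXLinearConfig(
    elem_dtype=torch.float8_e4m3fn,
    block_size=32,
    gemm_kernel_choice=gemm_kernel_choice,
)
quantize_(m, config)

# forward/backward now routes matmuls through mx_mm with the chosen kernel
x = torch.randn(32, 32, device="cuda", dtype=torch.bfloat16)
m(x).sum().backward()
```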
Yeh" Date: Fri, 18 Apr 2025 09:41:24 -0700 Subject: [PATCH 3/4] lint --- torchao/prototype/mx_formats/mx_ops.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/torchao/prototype/mx_formats/mx_ops.py b/torchao/prototype/mx_formats/mx_ops.py index f08fa9ee28..cca18f3b89 100644 --- a/torchao/prototype/mx_formats/mx_ops.py +++ b/torchao/prototype/mx_formats/mx_ops.py @@ -174,8 +174,7 @@ def mx_view_op(aten_op, args, kwargs=None): if args[0]._elem_dtype == DTYPE_FP4: # special case fp4 as we pack two elements per byte new_size = tensor_size_hp_to_fp4x2(new_size, data.is_contiguous()) - elif (args[0]._elem_dtype in [DTYPE_FP6_E3M2, DTYPE_FP6_E2M3] and - args[0]._pack_fp6): + elif args[0]._elem_dtype in [DTYPE_FP6_E3M2, DTYPE_FP6_E2M3] and args[0]._pack_fp6: # special case fp6 as we pack 4 elements in 3 bytes new_size = tensor_size_hpx3_to_fp6x4(new_size, data.is_contiguous()) new_data = aten_op(data, new_size, *args[2:], **kwargs) @@ -198,9 +197,9 @@ def autocast_to_copy(aten_op, args, kwargs=None): tensor. """ assert isinstance(args[0], MXTensor) - assert ( - len(kwargs) == 1 and "dtype" in kwargs - ), "Only support dtype kwarg for autocast" + assert len(kwargs) == 1 and "dtype" in kwargs, ( + "Only support dtype kwarg for autocast" + ) assert kwargs["dtype"] in { torch.float16, torch.bfloat16, From df2c2203eae3af02efb4963951284471af4ace0c Mon Sep 17 00:00:00 2001 From: "Peter Y. Yeh" Date: Mon, 5 May 2025 14:36:58 -0700 Subject: [PATCH 4/4] lint --- torchao/prototype/mx_formats/config.py | 18 +++++++++--------- torchao/prototype/mx_formats/mx_ops.py | 2 +- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/torchao/prototype/mx_formats/config.py b/torchao/prototype/mx_formats/config.py index 3df554e1da..f41aab817a 100644 --- a/torchao/prototype/mx_formats/config.py +++ b/torchao/prototype/mx_formats/config.py @@ -67,17 +67,17 @@ def _validate_gemm_kernel_choice(gemm_kernel_choice, block_size, elem_dtype): f"block_size must be 32 to use the cuBLAS MX gemm kernels, got {block_size}" ) valid_dtypes = [torch.float8_e4m3fn] - assert ( - elem_dtype in valid_dtypes - ), f"elem_dtype must be one of {valid_dtypes} to use the CUTLASS MX gemm kernels, got {elem_dtype}" + assert elem_dtype in valid_dtypes, ( + f"elem_dtype must be one of {valid_dtypes} to use the CUTLASS MX gemm kernels, got {elem_dtype}" + ) elif gemm_kernel_choice == MXGemmKernelChoice.HIPBLASLT: - assert ( - block_size == 32 - ), f"block_size must be 32 to use the HIPBLASLT MX gemm kernels, got {block_size}" + assert block_size == 32, ( + f"block_size must be 32 to use the HIPBLASLT MX gemm kernels, got {block_size}" + ) valid_dtypes = [torch.float8_e4m3fn] - assert ( - elem_dtype in valid_dtypes - ), f"elem_dtype must be one of {valid_dtypes} to use the HIPBLASLT MX gemm kernels, got {elem_dtype}" + assert elem_dtype in valid_dtypes, ( + f"elem_dtype must be one of {valid_dtypes} to use the HIPBLASLT MX gemm kernels, got {elem_dtype}" + ) assert torch.version.hip is not None, "HIPBLASLT requires ROCm" diff --git a/torchao/prototype/mx_formats/mx_ops.py b/torchao/prototype/mx_formats/mx_ops.py index 8795364da2..c510dc0c59 100644 --- a/torchao/prototype/mx_formats/mx_ops.py +++ b/torchao/prototype/mx_formats/mx_ops.py @@ -99,7 +99,7 @@ def mx_mm(aten_op, args, kwargs=None): MXGemmKernelChoice.CUBLAS, MXGemmKernelChoice.HIPBLASLT, ) - + assert a._gemm_kernel_choice is scaled_mm_kernels, ( "CUBLAS/HIPBLASLT is the only supported kernel choice for MX FP8 operations atm" )