diff --git a/torchao/prototype/mx_formats/README.md b/torchao/prototype/mx_formats/README.md
index 955a02704f..b9224c2594 100644
--- a/torchao/prototype/mx_formats/README.md
+++ b/torchao/prototype/mx_formats/README.md
@@ -1,7 +1,7 @@
 # MX training and inference with native PyTorch
 
 This is a workflow for e2e training and inference with MX dtypes from the [MX OCP spec](https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf)
-in native PyTorch. We are currently in prototype and are actively working on optimizing these workflows on the NVIDIA B200 hardware.
+in native PyTorch. We are currently in prototype and are actively working on optimizing these workflows on NVIDIA B200 and AMD MI355x hardware.
 
 ## Overall status
 
@@ -29,6 +29,9 @@ from torchao.prototype.mx_formats import MXLinearConfig, MXGemmKernelChoice
 gemm_kernel_choice = MXGemmKernelChoice.CUBLAS
 # gemm_kernel_choice = MXGemmKernelChoice.CUTLASS
 
+# on AMD MI355x GPUs with ROCm 6.5+ and gfx950, you can use HIPBLASLT mxfp8 kernels
+# gemm_kernel_choice = MXGemmKernelChoice.HIPBLASLT
+
 # on older NVIDIA gpus, you can run training with emulated MX gemm
 # gemm_kernel_choice = MXGemmKernelChoice.EMULATED
 
@@ -97,6 +100,8 @@ on supported hardware, you can run the following command:
 // example output: https://gist.github.com/vkuzo/a1ddb782e6e1c2aef0c726b3df99efbc
 ```
 
+On AMD MI355x GPUs with ROCm 6.5+ and gfx950, we use HIPBLASLT for mxfp8 gemm. We are actively working on optimizing end-to-end performance on AMD hardware.
+
 ## to_mx cast across dim0 and dim1
 
 On NVIDIA B200 machines, our to_mx kernels for mxfp8 achieve **up to 5.5 TB/s** for the dim0 cast (with torch.compile),
diff --git a/torchao/prototype/mx_formats/config.py b/torchao/prototype/mx_formats/config.py
index c49e1595a8..f41aab817a 100644
--- a/torchao/prototype/mx_formats/config.py
+++ b/torchao/prototype/mx_formats/config.py
@@ -33,11 +33,16 @@ class MXGemmKernelChoice(Enum):
     # note: torch.compile does not work yet, see https://github.com/pytorch/pytorch/issues/147873
     CUBLAS = "cublas"
 
+    # available only on ROCm with HIPBLASLT support
+    HIPBLASLT = "hipblaslt"
+
 
 # Pre-made recipes for common configurations
 class MXLinearRecipeName(Enum):
     MXFP8_EMULATED = "mxfp8_emulated"
     MXFP8_CUBLAS = "mxfp8_cublas"
+    MXFP8_CUTLASS = "mxfp8_cutlass"
+    MXFP8_HIPBLASLT = "mxfp8_hipblaslt"
     MXFP4_EMULATED = "mxfp4_emulated"
     MXFP4_CUTLASS = "mxfp4_cutlass"
 
@@ -65,6 +70,15 @@ def _validate_gemm_kernel_choice(gemm_kernel_choice, block_size, elem_dtype):
         assert elem_dtype in valid_dtypes, (
             f"elem_dtype must be one of {valid_dtypes} to use the CUTLASS MX gemm kernels, got {elem_dtype}"
         )
+    elif gemm_kernel_choice == MXGemmKernelChoice.HIPBLASLT:
+        assert block_size == 32, (
+            f"block_size must be 32 to use the HIPBLASLT MX gemm kernels, got {block_size}"
+        )
+        valid_dtypes = [torch.float8_e4m3fn]
+        assert elem_dtype in valid_dtypes, (
+            f"elem_dtype must be one of {valid_dtypes} to use the HIPBLASLT MX gemm kernels, got {elem_dtype}"
+        )
+        assert torch.version.hip is not None, "HIPBLASLT requires ROCm"
 
 
 @dataclass
@@ -125,6 +139,10 @@ def from_recipe_name(
             return MXLinearConfig()
         elif recipe_name is MXLinearRecipeName.MXFP8_CUBLAS:
            return MXLinearConfig(gemm_kernel_choice=MXGemmKernelChoice.CUBLAS)
+        elif recipe_name is MXLinearRecipeName.MXFP8_CUTLASS:
+            return MXLinearConfig(gemm_kernel_choice=MXGemmKernelChoice.CUTLASS)
+        elif recipe_name is MXLinearRecipeName.MXFP8_HIPBLASLT:
+            return MXLinearConfig(gemm_kernel_choice=MXGemmKernelChoice.HIPBLASLT)
         elif recipe_name is MXLinearRecipeName.MXFP4_EMULATED:
             return MXLinearConfig(elem_dtype=DTYPE_FP4)
         elif recipe_name is MXLinearRecipeName.MXFP4_CUTLASS:
diff --git a/torchao/prototype/mx_formats/mx_ops.py b/torchao/prototype/mx_formats/mx_ops.py
index af2d89c112..c510dc0c59 100644
--- a/torchao/prototype/mx_formats/mx_ops.py
+++ b/torchao/prototype/mx_formats/mx_ops.py
@@ -75,8 +75,14 @@ def mx_mm(aten_op, args, kwargs=None):
     b = args[1]
     assert isinstance(a, MXTensor) and isinstance(b, MXTensor)
     assert a._gemm_kernel_choice == b._gemm_kernel_choice, "unsupported"
-    if a._gemm_kernel_choice in (MXGemmKernelChoice.CUBLAS, MXGemmKernelChoice.CUTLASS):
-        # real MX gemm backed by torchao's CUTLASS kernels
+    kernel_choice = a._gemm_kernel_choice
+    valid_kernels = (
+        MXGemmKernelChoice.CUBLAS,
+        MXGemmKernelChoice.CUTLASS,
+        MXGemmKernelChoice.HIPBLASLT,
+    )
+    if kernel_choice in valid_kernels:
+        # real MX gemm backed by CUTLASS, cuBLAS, or hipBLASLt kernels
         M, K, N = a.shape[0], a.shape[1], b.shape[1]
         assert a._data.is_contiguous()
         assert b._data.t().is_contiguous()
@@ -88,8 +94,14 @@
         b_scale_block = to_blocked(b_scale)
         if a._elem_dtype == torch.float8_e4m3fn:
             assert b._elem_dtype == torch.float8_e4m3fn
-            assert a._gemm_kernel_choice is MXGemmKernelChoice.CUBLAS, (
-                "CUBLAS is the only supported kernel choice for MX FP8 operations"
+
+            scaled_mm_kernels = (
+                MXGemmKernelChoice.CUBLAS,
+                MXGemmKernelChoice.HIPBLASLT,
+            )
+
+            assert a._gemm_kernel_choice in scaled_mm_kernels, (
+                "CUBLAS and HIPBLASLT are the only supported kernel choices for MX FP8 operations"
             )
             res = torch._scaled_mm(
                 a._data,
@@ -98,10 +110,12 @@
                 b_scale_block.view(torch.float8_e8m0fnu),
                 out_dtype=torch.bfloat16,
             )
+
         else:
             assert a._elem_dtype == DTYPE_FP4
             assert b._elem_dtype == DTYPE_FP4
-            assert a._gemm_kernel_choice is MXGemmKernelChoice.CUTLASS, "unsupported"
+            msg = "FP4 is only supported with the CUTLASS kernel at this moment"
+            assert kernel_choice is MXGemmKernelChoice.CUTLASS, msg
             res = torchao.ops.mx_fp4_bf16(
                 a._data, b._data, a_scale_block, b_scale_block
             )
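For reference, here is a minimal usage sketch of the new kernel choice. It is not part of the patch above and makes a few assumptions: that `MXLinearConfig.from_recipe_name` is a classmethod (as the `config.py` hunk suggests), that `MXLinearRecipeName` is importable from `torchao.prototype.mx_formats.config`, and that `quantize_` from `torchao.quantization` accepts an `MXLinearConfig`; the toy model is purely illustrative.

```python
# Sketch (assumptions noted above): pick the MX gemm kernel based on the platform,
# using hipBLASLt on ROCm (MI355x / gfx950) and cuBLAS otherwise.
import torch
import torch.nn as nn

from torchao.quantization import quantize_
from torchao.prototype.mx_formats import MXLinearConfig
from torchao.prototype.mx_formats.config import MXGemmKernelChoice, MXLinearRecipeName

# torch.version.hip is non-None only on ROCm builds, mirroring the check
# added in _validate_gemm_kernel_choice
if torch.version.hip is not None:
    config = MXLinearConfig.from_recipe_name(MXLinearRecipeName.MXFP8_HIPBLASLT)
else:
    config = MXLinearConfig.from_recipe_name(MXLinearRecipeName.MXFP8_CUBLAS)

assert config.gemm_kernel_choice in (
    MXGemmKernelChoice.HIPBLASLT,
    MXGemmKernelChoice.CUBLAS,
)

# swap the linear layers to MX, following the README's training example;
# the toy model below is an illustrative assumption
m = nn.Sequential(
    nn.Linear(128, 256, bias=False, dtype=torch.bfloat16, device="cuda")
)
quantize_(m, config)
```

Both recipes keep the default `block_size=32` and fp8 e4m3 element dtype, which is exactly what the HIPBLASLT branch of `_validate_gemm_kernel_choice` requires.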