From 5a20dce23516316eb57a9f67a1564768ce8fb70c Mon Sep 17 00:00:00 2001
From: bjacob
Date: Thu, 19 Oct 2023 22:09:52 -0400
Subject: [PATCH] bf16: select appropriate tile sizes on x86 and Arm, and enable in x86 bitcode build (#15244)

This PR fixes 2 issues uncovered by e2e testing of bf16 matmuls, #15243.

I noticed that the optimized ukernel code paths weren't exercised as
they should be. There were 2 separate issues, both fixed by this PR:
1. We weren't picking the right tile size in MaterializeEncoding.
2. On x86, we weren't defining `IREE_UK_BUILD_X86_64_AVX512_BF16` in the
   bitcode build.
---
 .../Codegen/Common/CPU/CPUMaterializeEncodingPass.cpp | 11 +++++++++++
 .../ukernel/arch/x86_64/common_x86_64_entry_point.h   |  1 +
 2 files changed, 12 insertions(+)

diff --git a/compiler/src/iree/compiler/Codegen/Common/CPU/CPUMaterializeEncodingPass.cpp b/compiler/src/iree/compiler/Codegen/Common/CPU/CPUMaterializeEncodingPass.cpp
index d283a498ec94..bdf7f801fb72 100644
--- a/compiler/src/iree/compiler/Codegen/Common/CPU/CPUMaterializeEncodingPass.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/CPU/CPUMaterializeEncodingPass.cpp
@@ -57,6 +57,12 @@ chooseMatmulTileParamsAArch64(EncodingUser user, TypeRange elementTypes,
     Type out = elementTypes[2];
 
     if (out.isF32() || out.isF16() || out.isBF16()) {
+      if (lhs.isBF16() && rhs.isBF16() && (out.isBF16() || out.isF32())) {
+        if (hasFeature(target, "+bf16")) {
+          // Aim to use BFMMLA.
+          return MatmulTileParams{8, 4, 8};
+        }
+      }
       // Note: 16-bit floating point types currently use the same tile size as
       // f32. This makes sense when either (1) the accumulator is f32, or (2)
       // the arithmetic will have to expand f16 to f32 in registers. We may
@@ -94,6 +100,11 @@ chooseMatmulTileParamsX86_64(EncodingUser user, TypeRange elementTypes,
     Type out = elementTypes[2];
 
     if (out.isF32() || out.isF16() || out.isBF16()) {
+      if (lhs.isBF16() && rhs.isBF16() && (out.isBF16() || out.isF32())) {
+        if (hasFeature(target, "+avx512bf16")) {
+          return MatmulTileParams{16, 2, 16};
+        }
+      }
       // Note: 16-bit floating point types currently use the same tile size as
       // f32. This makes sense when either (1) the accumulator is f32, or (2)
       // the arithmetic will have to expand f16 to f32 in registers. We may
diff --git a/runtime/src/iree/builtins/ukernel/arch/x86_64/common_x86_64_entry_point.h b/runtime/src/iree/builtins/ukernel/arch/x86_64/common_x86_64_entry_point.h
index 9720e64c374c..5f5c11c8dcf3 100644
--- a/runtime/src/iree/builtins/ukernel/arch/x86_64/common_x86_64_entry_point.h
+++ b/runtime/src/iree/builtins/ukernel/arch/x86_64/common_x86_64_entry_point.h
@@ -15,6 +15,7 @@
 #define IREE_UK_BUILD_X86_64_AVX2_FMA
 #define IREE_UK_BUILD_X86_64_AVX512_BASE
 #define IREE_UK_BUILD_X86_64_AVX512_VNNI
+#define IREE_UK_BUILD_X86_64_AVX512_BF16
 #else // IREE_DEVICE_STANDALONE
 // Compiling with the system toolchain. Include the configured header.
 #include "iree/builtins/ukernel/arch/x86_64/config_x86_64.h"