bf16: select appropriate tile sizes on x86 and Arm, and enable in x86 bitcode build (iree-org#15244)

This PR fixes 2 issues uncovered by e2e testing of bf16 matmuls (iree-org#15243).
Both kept the optimized ukernel code paths from being exercised as they
should be:

1. We weren't picking the right tile size for bf16 matmuls in MaterializeEncoding.
2. On x86, we weren't defining `IREE_UK_BUILD_X86_64_AVX512_BF16` in the
   bitcode build.
bjacob authored Oct 20, 2023
1 parent 02e34b0 commit 5a20dce
Showing 2 changed files with 12 additions and 0 deletions.
@@ -57,6 +57,12 @@ chooseMatmulTileParamsAArch64(EncodingUser user, TypeRange elementTypes,
   Type out = elementTypes[2];

   if (out.isF32() || out.isF16() || out.isBF16()) {
+    if (lhs.isBF16() && rhs.isBF16() && (out.isBF16() || out.isF32())) {
+      if (hasFeature(target, "+bf16")) {
+        // Aim to use BFMMLA.
+        return MatmulTileParams{8, 4, 8};
+      }
+    }
     // Note: 16-bit floating point types currently use the same tile size as
     // f32. This makes sense when either (1) the accumulator is f32, or (2)
     // the arithmetic will have to expand f16 to f32 in registers. We may
@@ -94,6 +100,11 @@ chooseMatmulTileParamsX86_64(EncodingUser user, TypeRange elementTypes,
   Type out = elementTypes[2];

   if (out.isF32() || out.isF16() || out.isBF16()) {
+    if (lhs.isBF16() && rhs.isBF16() && (out.isBF16() || out.isF32())) {
+      if (hasFeature(target, "+avx512bf16")) {
+        return MatmulTileParams{16, 2, 16};
+      }
+    }
     // Note: 16-bit floating point types currently use the same tile size as
     // f32. This makes sense when either (1) the accumulator is f32, or (2)
     // the arithmetic will have to expand f16 to f32 in registers. We may
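
Both hunks above follow one pattern: if the LHS and RHS are bf16, the output is bf16 or f32, and the target advertises the matching CPU feature, return a bf16-specific tile; otherwise fall through to the existing 16-bit float path. The following is a minimal standalone sketch of that pattern, not the compiler's code. `ElemType`, `TileParams`, the string-list `hasFeature`, and `fallbackTile` are hypothetical stand-ins, and the field order M, K, N and the `{1, 1, 1}` fallback value are assumptions made only for illustration.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <string>
#include <vector>

// Hypothetical stand-in for the MLIR element types queried in the diff.
enum class ElemType { F32, F16, BF16 };

// Assumed field order (M, K, N); the diff only shows positional initializers.
struct TileParams {
  int64_t M;
  int64_t K;
  int64_t N;
};

// Hypothetical feature check over a plain list of feature strings.
inline bool hasFeature(const std::vector<std::string> &features,
                       const std::string &feature) {
  return std::find(features.begin(), features.end(), feature) !=
         features.end();
}

// Placeholder for the pre-existing 16-bit float path; not a real tile size.
inline TileParams fallbackTile() { return TileParams{1, 1, 1}; }

// True when both inputs are bf16 and the accumulator is bf16 or f32.
inline bool isBf16Matmul(ElemType lhs, ElemType rhs, ElemType out) {
  return lhs == ElemType::BF16 && rhs == ElemType::BF16 &&
         (out == ElemType::BF16 || out == ElemType::F32);
}

inline TileParams chooseTileAArch64(ElemType lhs, ElemType rhs, ElemType out,
                                    const std::vector<std::string> &features) {
  if (isBf16Matmul(lhs, rhs, out) && hasFeature(features, "+bf16")) {
    return TileParams{8, 4, 8};  // Tile from the diff, aimed at BFMMLA.
  }
  return fallbackTile();
}

inline TileParams chooseTileX86_64(ElemType lhs, ElemType rhs, ElemType out,
                                   const std::vector<std::string> &features) {
  if (isBf16Matmul(lhs, rhs, out) && hasFeature(features, "+avx512bf16")) {
    return TileParams{16, 2, 16};  // Tile from the diff.
  }
  return fallbackTile();
}

// Small self-check of the two outcomes on x86-64.
inline void demoTileChoice() {
  const std::vector<std::string> withBf16 = {"+avx512bf16"};
  const std::vector<std::string> without = {};
  assert(chooseTileX86_64(ElemType::BF16, ElemType::BF16, ElemType::F32,
                          withBf16).K == 2);
  assert(chooseTileX86_64(ElemType::BF16, ElemType::BF16, ElemType::F32,
                          without).K == 1);
}
```

Because the feature test is nested inside the type test, a target without the extension falls through to the previous behavior unchanged.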
@@ -15,6 +15,7 @@
 #define IREE_UK_BUILD_X86_64_AVX2_FMA
 #define IREE_UK_BUILD_X86_64_AVX512_BASE
 #define IREE_UK_BUILD_X86_64_AVX512_VNNI
+#define IREE_UK_BUILD_X86_64_AVX512_BF16
 #else // IREE_DEVICE_STANDALONE
 // Compiling with the system toolchain. Include the configured header.
 #include "iree/builtins/ukernel/arch/x86_64/config_x86_64.h"
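
To illustrate why a single missing define could leave the optimized path unexercised, here is a hedged sketch of compile-time gating. All `DEMO_*` macros and `selectMatmulPath` are hypothetical names that only mirror the shape of the hunk above; the sketch does not reproduce how the ukernel library actually dispatches.

```cpp
#include <cassert>
#include <cstring>

// Hypothetical: pretend this translation unit is the standalone build.
#define DEMO_DEVICE_STANDALONE

#if defined(DEMO_DEVICE_STANDALONE)
// Standalone build: each variant must be listed explicitly.
#define DEMO_BUILD_AVX512_BASE
#define DEMO_BUILD_AVX512_BF16  // Without this line the bf16 branch below
                                // is compiled out entirely.
#endif

// Returns the name of the code path that would be taken (illustrative only).
inline const char *selectMatmulPath(bool cpuHasAvx512Bf16) {
#if defined(DEMO_BUILD_AVX512_BF16)
  if (cpuHasAvx512Bf16) return "avx512_bf16";
#endif
#if defined(DEMO_BUILD_AVX512_BASE)
  (void)cpuHasAvx512Bf16;  // Silence unused warnings if the bf16 path is off.
  return "avx512_base";
#else
  (void)cpuHasAvx512Bf16;
  return "generic";
#endif
}

// Small self-check: with the define present, a capable CPU gets the bf16 path.
inline void demoPathSelection() {
  assert(std::strcmp(selectMatmulPath(true), "avx512_bf16") == 0);
  assert(std::strcmp(selectMatmulPath(false), "avx512_base") == 0);
}
```

In this sketch, deleting the `DEMO_BUILD_AVX512_BF16` define makes `selectMatmulPath(true)` return `"avx512_base"` even on a CPU that supports the extension, a silent fallback of the kind the commit message describes.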