Skip to content

Commit

Permalink
Add support for converting bf16 to uint16 on func ops. (iree-org#15231)
Browse files Browse the repository at this point in the history
This enables mmt4d ukernels on bf16xbf16->f32.
  • Loading branch information
hanhanW authored Oct 19, 2023
1 parent 82611a9 commit 8b1af38
Show file tree
Hide file tree
Showing 4 changed files with 65 additions and 3 deletions.
1 change: 1 addition & 0 deletions compiler/src/iree/compiler/Codegen/Common/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,7 @@ iree_compiler_cc_library(
"@llvm-project//mlir:BufferizationTransforms",
"@llvm-project//mlir:DialectUtils",
"@llvm-project//mlir:FuncDialect",
"@llvm-project//mlir:FuncTransforms",
"@llvm-project//mlir:GPUDialect",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:LLVMCommonConversion",
Expand Down
1 change: 1 addition & 0 deletions compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,7 @@ iree_cc_library(
MLIRBufferizationDialect
MLIRBufferizationTransforms
MLIRFuncDialect
MLIRFuncTransforms
MLIRGPUDialect
MLIRIR
MLIRLLVMCommonConversion
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Transforms/Passes.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
Expand Down Expand Up @@ -54,6 +55,18 @@ class Bf16EmulationConverter : public TypeConverter {
addConversion([this](ShapedType ty) -> std::optional<Type> {
return ty.clone(convertType(ty.getElementType()));
});

// Function types are converted signature-wise: every input type and every
// result type is run through this converter, and a new FunctionType is built
// from the converted lists. If either list fails to convert, no conversion
// is produced (std::nullopt).
addConversion([this](FunctionType fnTy) -> std::optional<Type> {
  SmallVector<Type> convertedInputs;
  SmallVector<Type> convertedResults;
  // Inputs are attempted first; results are only attempted if inputs succeed.
  const bool conversionFailed =
      failed(convertTypes(fnTy.getInputs(), convertedInputs)) ||
      failed(convertTypes(fnTy.getResults(), convertedResults));
  if (conversionFailed)
    return std::nullopt;

  return FunctionType::get(fnTy.getContext(), convertedInputs,
                           convertedResults);
});
}
};

Expand Down Expand Up @@ -217,6 +230,10 @@ std::optional<Value> materializeArithBitcast(OpBuilder &builder, Type resultTy,

static void populateIreeBf16EmulationPatterns(RewritePatternSet &patterns,
TypeConverter &typeConverter) {
populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
typeConverter);
populateCallOpTypeConversionPattern(patterns, typeConverter);
populateReturnOpTypeConversionPattern(patterns, typeConverter);
patterns.add<GenericTypeConversionPattern, ConvertHalInterfaceBindingSubspan,
ConvertMemRefAlloc, ConvertMemRefLoad, ConvertMemRefStore>(
typeConverter, patterns.getContext());
Expand Down Expand Up @@ -244,7 +261,6 @@ struct ConvertBf16ToUInt16BuffersPass final
// Run the main emulation pass.
{
ConversionTarget target(*ctx);
target.addLegalOp<func::ReturnOp>();
target.addDynamicallyLegalOp<func::FuncOp>([&typeConverter](
Operation *op) {
return typeConverter.isLegal(cast<func::FuncOp>(op).getFunctionType());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,52 @@ func.func @bf16_conversion() {
// CHECK-LABEL: @bf16_constant
func.func @bf16_constant(%arg0 : bf16) -> bf16 {
// CHECK: %[[CNST:.+]] = arith.constant 16256 : i16
// CHECK: %[[CAST:.+]] = arith.bitcast %[[CNST]]
%c0 = arith.constant 1.0 : bf16
// CHECK: return %[[CAST]]
// CHECK: return %[[CNST]]
return %c0 : bf16
}

// -----

// Body-less private declaration of the mmt4d ukernel entry point. Its
// signature carries two memref<bf16> arguments and one memref<f32> argument;
// the expectations below require the two bf16 memrefs to show up as
// memref<i16> in the converted signature while the f32 memref is unchanged.
// CHECK-LABEL: @iree_uk_mmt4d
// CHECK-SAME: memref<i16>
// CHECK-SAME: memref<i16>
// CHECK-SAME: memref<f32>
func.func private @iree_uk_mmt4d(memref<bf16>, index, index, memref<bf16>, index, index, memref<f32>, index, index, index, index, index, i32, i32, i32, i32) attributes {hal.import.bitcode = true, hal.import.cconv = 1 : i32, hal.import.fields = ["processor_data"], llvm.bareptr = true}

// Caller-side counterpart of the @iree_uk_mmt4d declaration: the function
// below passes two memref<bf16> base buffers and one memref<f32> base buffer
// to func.call. The expectations require the converted call to list
// memref<i16>, memref<i16>, memref<f32> in that order.
// CHECK-LABEL: @mmt4d_bf16xbf16xf32
// CHECK: func.call
// CHECK-SAME: memref<i16>
// CHECK-SAME: memref<i16>
// CHECK-SAME: memref<f32>
func.func @mmt4d_bf16xbf16xf32() {
%c32 = arith.constant 32 : index
%c24 = arith.constant 24 : index
%c3 = arith.constant 3 : index
%c8_i32 = arith.constant 8 : i32
%c1_i32 = arith.constant 1 : i32
%c1029_i32 = arith.constant 1029 : i32
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
%c64 = arith.constant 64 : index
%c128 = arith.constant 128 : index
// Two read-only bf16 buffers on binding 0 (byte offsets %c0 and %c64) and one
// f32 buffer on binding 1 (offset %c128).
%0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : memref<1x3x8x1xbf16>
memref.assume_alignment %0, 64 : memref<1x3x8x1xbf16>
%1 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c64) flags(ReadOnly) : memref<1x3x8x1xbf16, strided<[24, 8, 1, 1], offset: 32>>
memref.assume_alignment %1, 64 : memref<1x3x8x1xbf16, strided<[24, 8, 1, 1], offset: 32>>
%2 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c128) : memref<1x1x8x8xf32, strided<[64, 64, 8, 1], offset: 32>>
memref.assume_alignment %2, 64 : memref<1x1x8x8xf32, strided<[64, 64, 8, 1], offset: 32>>
%workgroup_id_x = hal.interface.workgroup.id[0] : index
%workgroup_count_x = hal.interface.workgroup.count[0] : index
%workgroup_id_y = hal.interface.workgroup.id[1] : index
%workgroup_count_y = hal.interface.workgroup.count[1] : index
// Both loops start at the workgroup id and step by the workgroup count.
scf.for %arg0 = %workgroup_id_y to %c1 step %workgroup_count_y {
scf.for %arg1 = %workgroup_id_x to %c1 step %workgroup_count_x {
// Decompose each buffer into its rank-0 base memref plus offset/sizes/strides;
// only the base memrefs are forwarded to the call below.
%base_buffer, %offset, %sizes:4, %strides:4 = memref.extract_strided_metadata %0 : memref<1x3x8x1xbf16> -> memref<bf16>, index, index, index, index, index, index, index, index, index
%base_buffer_0, %offset_1, %sizes_2:4, %strides_3:4 = memref.extract_strided_metadata %1 : memref<1x3x8x1xbf16, strided<[24, 8, 1, 1], offset: 32>> -> memref<bf16>, index, index, index, index, index, index, index, index, index
%base_buffer_4, %offset_5, %sizes_6:4, %strides_7:4 = memref.extract_strided_metadata %2 : memref<1x1x8x8xf32, strided<[64, 64, 8, 1], offset: 32>> -> memref<f32>, index, index, index, index, index, index, index, index, index
// The call whose operand types the expectations above inspect.
func.call @iree_uk_mmt4d(%base_buffer, %c0, %c24, %base_buffer_0, %c32, %c24, %base_buffer_4, %c32, %c64, %c1, %c1, %c3, %c8_i32, %c8_i32, %c1_i32, %c1029_i32) : (memref<bf16>, index, index, memref<bf16>, index, index, memref<f32>, index, index, index, index, index, i32, i32, i32, i32) -> ()
}
}
return
}

0 comments on commit 8b1af38

Please sign in to comment.