diff --git a/compiler/src/iree/compiler/Codegen/Common/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Common/BUILD.bazel index 2a44e1d924a7..f53f6160c594 100644 --- a/compiler/src/iree/compiler/Codegen/Common/BUILD.bazel +++ b/compiler/src/iree/compiler/Codegen/Common/BUILD.bazel @@ -238,6 +238,7 @@ iree_compiler_cc_library( "@llvm-project//mlir:BufferizationTransforms", "@llvm-project//mlir:DialectUtils", "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:FuncTransforms", "@llvm-project//mlir:GPUDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:LLVMCommonConversion", diff --git a/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt index eb519d24cf4b..f047fe1f5015 100644 --- a/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt +++ b/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt @@ -188,6 +188,7 @@ iree_cc_library( MLIRBufferizationDialect MLIRBufferizationTransforms MLIRFuncDialect + MLIRFuncTransforms MLIRGPUDialect MLIRIR MLIRLLVMCommonConversion diff --git a/compiler/src/iree/compiler/Codegen/Common/ConvertBf16ToUInt16Buffers.cpp b/compiler/src/iree/compiler/Codegen/Common/ConvertBf16ToUInt16Buffers.cpp index 1df9f5342f8b..0fed4629b0ca 100644 --- a/compiler/src/iree/compiler/Codegen/Common/ConvertBf16ToUInt16Buffers.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/ConvertBf16ToUInt16Buffers.cpp @@ -22,6 +22,7 @@ #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Arith/Transforms/Passes.h" #include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Func/Transforms/FuncConversions.h" #include "mlir/Dialect/MemRef/Transforms/Transforms.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" @@ -54,6 +55,18 @@ class Bf16EmulationConverter : public TypeConverter { addConversion([this](ShapedType ty) -> std::optional { return ty.clone(convertType(ty.getElementType())); }); + + addConversion([this](FunctionType ty) -> std::optional { + SmallVector inputs; + if (failed(convertTypes(ty.getInputs(), inputs))) + return std::nullopt; + + SmallVector results; + if (failed(convertTypes(ty.getResults(), results))) + return std::nullopt; + + return FunctionType::get(ty.getContext(), inputs, results); + }); } }; @@ -217,6 +230,10 @@ std::optional materializeArithBitcast(OpBuilder &builder, Type resultTy, static void populateIreeBf16EmulationPatterns(RewritePatternSet &patterns, TypeConverter &typeConverter) { + populateFunctionOpInterfaceTypeConversionPattern(patterns, + typeConverter); + populateCallOpTypeConversionPattern(patterns, typeConverter); + populateReturnOpTypeConversionPattern(patterns, typeConverter); patterns.add( typeConverter, patterns.getContext()); @@ -244,7 +261,6 @@ struct ConvertBf16ToUInt16BuffersPass final // Run the main emulation pass. { ConversionTarget target(*ctx); - target.addLegalOp(); target.addDynamicallyLegalOp([&typeConverter]( Operation *op) { return typeConverter.isLegal(cast(op).getFunctionType()); diff --git a/compiler/src/iree/compiler/Codegen/Common/test/convert_bf16_to_uint16_buffers.mlir b/compiler/src/iree/compiler/Codegen/Common/test/convert_bf16_to_uint16_buffers.mlir index 9df3a757c274..8b640bb04dbc 100644 --- a/compiler/src/iree/compiler/Codegen/Common/test/convert_bf16_to_uint16_buffers.mlir +++ b/compiler/src/iree/compiler/Codegen/Common/test/convert_bf16_to_uint16_buffers.mlir @@ -34,8 +34,52 @@ func.func @bf16_conversion() { // CHECK-LABEL: @bf16_constant func.func @bf16_constant(%arg0 : bf16) -> bf16 { // CHECK: %[[CNST:.+]] = arith.constant 16256 : i16 - // CHECK: %[[CAST:.+]] = arith.bitcast %[[CNST]] %c0 = arith.constant 1.0 : bf16 - // CHECK: return %[[CAST]] + // CHECK: return %[[CNST]] return %c0 : bf16 } + +// ----- + +// CHECK-LABEL: @iree_uk_mmt4d +// CHECK-SAME: memref +// CHECK-SAME: memref +// CHECK-SAME: memref +func.func private @iree_uk_mmt4d(memref, index, index, memref, index, index, memref, index, index, index, index, index, i32, i32, i32, i32) attributes {hal.import.bitcode = true, hal.import.cconv = 1 : i32, hal.import.fields = ["processor_data"], llvm.bareptr = true} + +// CHECK-LABEL: @mmt4d_bf16xbf16xf32 +// CHECK: func.call +// CHECK-SAME: memref +// CHECK-SAME: memref +// CHECK-SAME: memref +func.func @mmt4d_bf16xbf16xf32() { + %c32 = arith.constant 32 : index + %c24 = arith.constant 24 : index + %c3 = arith.constant 3 : index + %c8_i32 = arith.constant 8 : i32 + %c1_i32 = arith.constant 1 : i32 + %c1029_i32 = arith.constant 1029 : i32 + %c1 = arith.constant 1 : index + %c0 = arith.constant 0 : index + %c64 = arith.constant 64 : index + %c128 = arith.constant 128 : index + %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : memref<1x3x8x1xbf16> + memref.assume_alignment %0, 64 : memref<1x3x8x1xbf16> + %1 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c64) flags(ReadOnly) : memref<1x3x8x1xbf16, strided<[24, 8, 1, 1], offset: 32>> + memref.assume_alignment %1, 64 : memref<1x3x8x1xbf16, strided<[24, 8, 1, 1], offset: 32>> + %2 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c128) : memref<1x1x8x8xf32, strided<[64, 64, 8, 1], offset: 32>> + memref.assume_alignment %2, 64 : memref<1x1x8x8xf32, strided<[64, 64, 8, 1], offset: 32>> + %workgroup_id_x = hal.interface.workgroup.id[0] : index + %workgroup_count_x = hal.interface.workgroup.count[0] : index + %workgroup_id_y = hal.interface.workgroup.id[1] : index + %workgroup_count_y = hal.interface.workgroup.count[1] : index + scf.for %arg0 = %workgroup_id_y to %c1 step %workgroup_count_y { + scf.for %arg1 = %workgroup_id_x to %c1 step %workgroup_count_x { + %base_buffer, %offset, %sizes:4, %strides:4 = memref.extract_strided_metadata %0 : memref<1x3x8x1xbf16> -> memref, index, index, index, index, index, index, index, index, index + %base_buffer_0, %offset_1, %sizes_2:4, %strides_3:4 = memref.extract_strided_metadata %1 : memref<1x3x8x1xbf16, strided<[24, 8, 1, 1], offset: 32>> -> memref, index, index, index, index, index, index, index, index, index + %base_buffer_4, %offset_5, %sizes_6:4, %strides_7:4 = memref.extract_strided_metadata %2 : memref<1x1x8x8xf32, strided<[64, 64, 8, 1], offset: 32>> -> memref, index, index, index, index, index, index, index, index, index + func.call @iree_uk_mmt4d(%base_buffer, %c0, %c24, %base_buffer_0, %c32, %c24, %base_buffer_4, %c32, %c64, %c1, %c1, %c3, %c8_i32, %c8_i32, %c1_i32, %c1029_i32) : (memref, index, index, memref, index, index, memref, index, index, index, index, index, i32, i32, i32, i32) -> () + } + } + return +}