diff --git a/compiler/src/iree/compiler/Codegen/Dialect/IREECodegenAttrs.td b/compiler/src/iree/compiler/Codegen/Dialect/IREECodegenAttrs.td
index 797444637e3e5..529df7a538ad0 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/IREECodegenAttrs.td
+++ b/compiler/src/iree/compiler/Codegen/Dialect/IREECodegenAttrs.td
@@ -55,8 +55,6 @@ def SPIRV_WinogradVectorize
 def VMVX_Default : I32EnumAttrCase<"VMVXDefault", 24>;
 
-def CPU_AccelMatmulExpert
-    : I32EnumAttrCase<"AccelMatmulExpert", 25>;
 
 def Linalg_TransformDialectCodegen
     : I32EnumAttrCase<"TransformDialectCodegen", 100>;
@@ -78,7 +76,7 @@ def DispatchLoweringPassPipelineEnum
       LLVMGPU_PackUnPack, LLVMGPU_MatmulTensorCoreMmaSync,
       SPIRV_BaseLowering, SPIRV_BaseDistribute, SPIRV_BaseVectorize,
       SPIRV_MatmulPromoteVectorize, SPIRV_CooperativeMatrixVectorize,
-      SPIRV_SubgroupReduce, SPIRV_WinogradVectorize, VMVX_Default, CPU_AccelMatmulExpert,
+      SPIRV_SubgroupReduce, SPIRV_WinogradVectorize, VMVX_Default,
       // Transform dialect based codegen
       Linalg_TransformDialectCodegen, None
     ]> {
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp
index 1b949f37c2559..b4cb038fefd11 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp
@@ -1159,18 +1159,6 @@ static LogicalResult setRootConfig(func::FuncOp entryPointFn,
       DispatchLoweringPassPipeline::Mmt4dTilingExpert);
 }
 
-/// Sets the lowering configuration for dispatch region for linalg.matmul root
-/// op
-static LogicalResult setRootConfig(func::FuncOp entryPointFn,
-                                   linalg::MatmulOp matmulOp) {
-  assert(!getLoweringConfig(matmulOp) && "expected lowering_config is not set");
-  SmallVector<int64_t> tileSizes;
-  tileSizes.push_back(1);
-  return setOpConfigAndEntryPointFnTranslation(
-      entryPointFn, matmulOp, tileSizes,
-      DispatchLoweringPassPipeline::AccelMatmulExpert);
-}
-
 /// Sets the lowering configuration for dispatch region for linalg.batch_mmt4d
 /// root op
 static LogicalResult setRootConfig(func::FuncOp entryPointFn,
@@ -1993,7 +1981,7 @@ setRootConfigImpl(func::FuncOp entryPointFn, Operation *op,
                                targetMLTransInfo);
         })
-        .Case<IREE::LinalgExt::FftOp, tensor::PackOp, tensor::PadOp, linalg::MatmulOp,
+        .Case<IREE::LinalgExt::FftOp, tensor::PackOp, tensor::PadOp,
               linalg::Mmt4DOp, linalg::BatchMmt4DOp>(
             [&](auto op) { return setRootConfig(entryPointFn, op); })
         .Case<linalg::Conv2DNhwcHwcfOp, linalg::Conv2DNchwFchwOp,
               linalg::PoolingNhwcSumOp, linalg::PoolingNhwcMaxOp,
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPULowerToAccelUKernels.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPULowerToAccelUKernels.cpp
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPULowerToAccelUKernels.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPULowerToAccelUKernels.cpp
@@ -39,14 +39,4 @@
 
-/// Returns true if the given `outsOperand` value is initialized to zero.
-static bool isInitializedToZero(Value outsOperand) {
-  auto fillOp = outsOperand.getDefiningOp<linalg::FillOp>();
-  if (!fillOp)
-    return false;
-  Value fillVal = fillOp.getDpsInputOperand(0)->get();
-  return matchPattern(fillVal, m_Zero()) ||
-         matchPattern(fillVal, m_AnyZeroFloat());
-}
-
 /// Holds a function name and attributes.
 struct FnNameAndDefAttrs {
   std::string name;
@@ -82,35 +72,87 @@ getFnNameAndDefAttrs(const char *ukernelName, RewriterBase &rewriter,
   return result;
 }
 
-/// Matches an (linalg.fill -> )? linalg.matmul operation sequence and converts
-/// it into a iree_codegen.ukernel.generic "accel_matmul_f32" operation, that is later lowered
-/// into a call to the microkernel.
+/// Matches a linalg.mmt4d operation and converts it into an
+/// iree_codegen.ukernel.generic "accel_matmul_f32" operation, which is later
+/// lowered into a call to the microkernel.
 static FailureOr<IREE::Codegen::UKernelOpInterface>
-matchDAGForUKernel(RewriterBase &rewriter, linalg::MatmulOp op) {
+matchDAGForUKernel(RewriterBase &rewriter, linalg::Mmt4DOp op) {
   Value lhs = op.getDpsInputOperand(0)->get();
   Value rhs = op.getDpsInputOperand(1)->get();
   Value out = op.getDpsInitOperand(0)->get();
+  auto lhsType = llvm::cast<ShapedType>(lhs.getType());
+  auto rhsType = llvm::cast<ShapedType>(rhs.getType());
   auto outType = llvm::cast<ShapedType>(out.getType());
-
-  // Check if the accumulator is zero-filled.
-  if (isInitializedToZero(out)) {
-    // The plugin will not read the existing accumulator, so its defining op can
-    // be discarded.
-    if (auto fillOp = out.getDefiningOp<linalg::FillOp>()) {
-      out = fillOp.getDpsInitOperand(0)->get();
-    }
+  Type lhsElemType = lhsType.getElementType();
+  Type rhsElemType = rhsType.getElementType();
+  Type outElemType = outType.getElementType();
+  uint32_t flags = 0;
+  if (lhsElemType.isSignlessInteger(8) && rhsElemType.isSignlessInteger(8) &&
+      outElemType.isSignlessInteger(32)) {
+    flags = IREE_UK_FLAG_MMT4D_TYPE_I8I8I32;
+  } else if (lhsElemType.isF32() && rhsElemType.isF32() &&
+             outElemType.isF32()) {
+    flags = IREE_UK_FLAG_MMT4D_TYPE_F32F32F32;
+  } else {
+    return rewriter.notifyMatchFailure(
+        op, "unsupported combination of element types");
   }
 
   Location loc = op.getLoc();
-  Value m = rewriter.create<tensor::DimOp>(loc, lhs, 0);
-  Value n = rewriter.create<tensor::DimOp>(loc, rhs, 0);
-  Value k = rewriter.create<tensor::DimOp>(loc, rhs, 1);
+
+  if (outType.getShape()[0] != 1 || outType.getShape()[1] != 1) {
+    return rewriter.notifyMatchFailure(op, "outer dims need to be 1");
+  }
+  auto outTypeRanked = out.getType().cast<RankedTensorType>();
+  RankedTensorType intermediateOutType =
+      RankedTensorType::Builder(outTypeRanked).dropDim(0);
+  RankedTensorType reducedOutType =
+      RankedTensorType::Builder(intermediateOutType).dropDim(0);
+  Value reducedOut = tensor::createCanonicalRankReducingExtractSliceOp(
+      rewriter, loc, out, reducedOutType);
+
+  auto lhsTypeRanked = lhs.getType().cast<RankedTensorType>();
+  RankedTensorType intermediateLhsType =
+      RankedTensorType::Builder(lhsTypeRanked).dropDim(0);
+  RankedTensorType reducedLhsType =
+      RankedTensorType::Builder(intermediateLhsType).dropDim(0);
+  auto reducedLhs = tensor::createCanonicalRankReducingExtractSliceOp(
+      rewriter, loc, lhs, reducedLhsType);
+
+  auto rhsTypeRanked = rhs.getType().cast<RankedTensorType>();
+  RankedTensorType intermediateRhsType =
+      RankedTensorType::Builder(rhsTypeRanked).dropDim(0);
+  RankedTensorType reducedRhsType =
+      RankedTensorType::Builder(intermediateRhsType).dropDim(0);
+  auto reducedRhs = tensor::createCanonicalRankReducingExtractSliceOp(
+      rewriter, loc, rhs, reducedRhsType);
+/*
+  auto getDimAsI32 = [](RewriterBase &rewriter, Location loc, Value value,
+                        int dim) -> Value {
+    return rewriter.create<arith::IndexCastOp>(
+        loc, rewriter.getI32Type(),
+        rewriter.create<tensor::DimOp>(loc, value, dim));
+  };
+  Value m = getDimAsI32(rewriter, loc, reducedLhs, 0);
+  Value n = getDimAsI32(rewriter, loc, reducedRhs, 0);
+  Value k = getDimAsI32(rewriter, loc, reducedRhs, 1);
+*/
+  Value m = rewriter.create<tensor::DimOp>(loc, reducedLhs, 0);
+  Value n = rewriter.create<tensor::DimOp>(loc, reducedRhs, 0);
+  Value k = rewriter.create<tensor::DimOp>(loc, reducedRhs, 1);
+  Value flagsVal = rewriter.create<arith::ConstantOp>(
+      loc, rewriter.getI32IntegerAttr(flags));
   auto targetAttr = IREE::HAL::ExecutableTargetAttr::lookup(op);
   auto fn = getFnNameAndDefAttrs("accel_matmul_f32", rewriter, targetAttr);
   auto genericMicroKernelOp = rewriter.create<IREE::Codegen::UKernelGenericOp>(
-      loc, outType, fn.name, ValueRange{lhs, rhs}, out, ValueRange{m, n, k},
+      loc, reducedOutType, fn.name, ValueRange{reducedLhs, reducedRhs}, reducedOut,
+      ValueRange{m, n, k, flagsVal},
       /*fn_def_attrs=*/rewriter.getDictionaryAttr(fn.defAttrs),
       /*strided_outer_dims=*/rewriter.getIndexAttr(0));
+  auto insertSliceOp = tensor::createCanonicalRankReducingInsertSliceOp(
+      rewriter, loc, genericMicroKernelOp.getResult(0), out);
+  op.getResults()[0].replaceAllUsesWith(insertSliceOp);
   return cast<IREE::Codegen::UKernelOpInterface>(
       genericMicroKernelOp.getOperation());
 }
@@ -143,7 +185,7 @@ void LLVMCPULowerToAccelUKernelsPass::runOnOperation() {
   // Since microkernels are linked as bitcode, they will still undergo LTO-like
   // optimization in their calling contexts, but we shouldn't expect this to
   // achieve similar results as fusing structured ops.
-  patterns.insert<LowerToAccelUKernelPattern<linalg::MatmulOp>>(context);
+  patterns.insert<LowerToAccelUKernelPattern<linalg::Mmt4DOp>>(context);
   if (failed(
           applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) {
     return signalPassFailure();
   }
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp
index 49e60f69028b9..6dab1e15266fc 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp
@@ -57,7 +57,7 @@ static llvm::cl::opt<bool> clEnablePadConsumerFusion(
 
 static llvm::cl::opt<bool> clEnableAccelMicrokernels(
     "iree-llvmcpu-enable-accel-ukernels",
-    llvm::cl::desc("Flag to enable lowering to accelUkernels"),
+    llvm::cl::desc("Flag to enable lowering to accel microkernels"),
     llvm::cl::init(false));
 
 static llvm::cl::opt<bool> clEnableMicrokernelsDecomposeLinalgGeneric(
@@ -611,10 +611,20 @@ void addMmt4dTilingExpertPassPipeline(OpPassManager &passManager,
 
   OpPassManager &nestedModulePM = passManager.nest<ModuleOp>();
 
-  if (enableMicrokernels) {
+  if (clEnableAccelMicrokernels) {
+    nestedModulePM.addNestedPass<func::FuncOp>(createLLVMCPUTileAndFusePass(
+        static_cast<int64_t>(tilingConfig.getVectorCommonParallelLevel())));
+    nestedModulePM.addNestedPass<func::FuncOp>(createLLVMCPUTilePass(
+        static_cast<int64_t>(tilingConfig.getVectorReductionLevel())));
+    nestedModulePM.addNestedPass<func::FuncOp>(
+        createDecomposeBatchMmt4DOpsPass());
+    nestedModulePM.addPass(
+        createLLVMCPULowerToAccelUKernelsPass());
+    nestedModulePM.addNestedPass<func::FuncOp>(
+        createConvertToDestinationPassingStylePass());
+  } else if (enableMicrokernels) {
     nestedModulePM.addNestedPass<func::FuncOp>(
         createDecomposeBatchMmt4DOpsPass());
-    nestedModulePM.addPass(createLLVMCPULowerToAccelUKernelsPass());
     nestedModulePM.addPass(
         createLLVMCPULowerToUKernelsPass(clSkipIntermediateRoundings));
   } else {
@@ -639,32 +649,6 @@ void addMmt4dTilingExpertPassPipeline(OpPassManager &passManager,
   }
 }
 
-void addAccelMatmulExpertPassPipeline(OpPassManager &passManager,
-                                      TilingConfig &tilingConfig,
-                                      bool enableAccelMicrokernels) {
-  addTileAndDistributePasses(passManager);
-
-  OpPassManager &nestedModulePM = passManager.nest<ModuleOp>();
-
-  if (enableAccelMicrokernels) {
-    nestedModulePM.addPass(createLLVMCPULowerToAccelUKernelsPass());
-  } else {
-    nestedModulePM.addNestedPass<func::FuncOp>(createLLVMCPUTileAndFusePass(
-        static_cast<int64_t>(tilingConfig.getVectorCommonParallelLevel())));
-    nestedModulePM.addNestedPass<func::FuncOp>(createLLVMCPUTilePass(
-        static_cast<int64_t>(tilingConfig.getVectorReductionLevel())));
-    nestedModulePM.addNestedPass<func::FuncOp>(
-        createGenericVectorizationPass());
-    nestedModulePM.addNestedPass<func::FuncOp>(
-        createHoistRedundantVectorTransfersPass());
-  }
-
-  nestedModulePM.addNestedPass<func::FuncOp>(createCanonicalizerPass());
-  nestedModulePM.addNestedPass<func::FuncOp>(createCSEPass());
-
-  addBufferizePasses(nestedModulePM);
-}
-
 void addCPUDataTilingPipeline(OpPassManager &passManager,
                               TilingConfig &tilingConfig,
                               bool enableVectorMasking) {
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.h b/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.h
index d2e51001f3ecb..7f536ac1f4e47 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.h
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.h
@@ -149,10 +149,6 @@ void addMmt4dTilingExpertPassPipeline(OpPassManager &passManager,
                                       TilingConfig &tilingConfig,
                                       bool enableMicrokernels);
 
-void addAccelMatmulExpertPassPipeline(OpPassManager &passManager,
-                                      TilingConfig &tilingConfig,
-                                      bool enableAccelMicrokernels);
-
 void addMultiTilingExpertPassPipeline(
     OpPassManager &passManager, TilingConfig &tilingConfig, bool enablePeeling,
     bool enableVectorMasking, bool lowerToAVX2, bool enableAArch64SSVE = false);
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/test/lower_to_accel_ukernel_ops.mlir b/compiler/src/iree/compiler/Codegen/LLVMCPU/test/lower_to_accel_ukernel_ops.mlir
index e41343f394d54..b1fc88824e635 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/test/lower_to_accel_ukernel_ops.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/test/lower_to_accel_ukernel_ops.mlir
@@ -1,10 +1,12 @@
 // RUN: iree-opt --pass-pipeline="builtin.module(func.func(iree-llvmcpu-lower-to-accel-ukernels,cse,canonicalize))" %s | FileCheck %s
-func.func @matmul_f32f32f32(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>, %arg2 : tensor<?x?xf32>) -> tensor<?x?xf32> {
-  %0 = linalg.matmul ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
-      outs(%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32>
-  return %0 : tensor<?x?xf32>
+func.func @mmt4d_f32f32f32(%arg0 : tensor<1x1x?x?xf32>, %arg1 : tensor<1x1x?x?xf32>,
+    %arg2 : tensor<1x1x?x?xf32>) -> tensor<1x1x?x?xf32> {
+  %0 = linalg.mmt4d ins(%arg0, %arg1 : tensor<1x1x?x?xf32>, tensor<1x1x?x?xf32>)
+      outs(%arg2 : tensor<1x1x?x?xf32>) -> tensor<1x1x?x?xf32>
+  return %0 : tensor<1x1x?x?xf32>
 }
-// CHECK: func @matmul_f32f32f32(
+
+// CHECK: func @mmt4d_f32f32f32(
 // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor
 // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor
 // CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor
@@ -14,7 +16,7 @@ func.func @matmul_f32f32f32(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>, %arg2 : tensor<?x?xf32>) -> tensor<?x?xf32> {
 // CHECK-DAG: %[[M:.+]] = tensor.dim %[[ARG0]], %[[C0]]
 // CHECK-DAG: %[[N:.+]] = tensor.dim %[[ARG1]], %[[C0]]
 // CHECK-DAG: %[[K:.+]] = tensor.dim %[[ARG1]], %[[C1]]
-// CHECK: %[[MICRO_KERNEL:.+]] = iree_codegen.ukernel.generic "aie_matmul_f32"
+// CHECK: %[[MICRO_KERNEL:.+]] = iree_codegen.ukernel.generic "accel_matmul_f32"
 // CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] :
 // CHECK-SAME: outs(%[[ARG2]] :
 // CHECK-SAME: (%[[M]], %[[N]], %[[K]] :
diff --git a/samples/custom_dispatch/cpu/plugin/CMakeLists.txt b/samples/custom_dispatch/cpu/plugin/CMakeLists.txt
index cb675f0fcf08a..59a5c793f998e 100644
--- a/samples/custom_dispatch/cpu/plugin/CMakeLists.txt
+++ b/samples/custom_dispatch/cpu/plugin/CMakeLists.txt
@@ -21,7 +21,6 @@ target_include_directories(iree_samples_custom_dispatch_cpu_system_plugin
     ${IREE_SOURCE_DIR}/runtime/src/
 )
 
-iree_add_all_subdirs()
 # NOTE: this is only required because we want this sample to run on all
 # platforms without needing to change the library name (libfoo.so/foo.dll).
 set_target_properties(iree_samples_custom_dispatch_cpu_system_plugin
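
For reference, a minimal MLIR sketch of the DAG that the matchDAGForUKernel rewrite above
produces (hand-written for illustration, not captured pass output; the function name, SSA
value names, the flag constant's value, and the exact printed form of
iree_codegen.ukernel.generic are assumptions):

  func.func @accel_mmt4d_sketch(%lhs : tensor<1x1x?x?xf32>, %rhs : tensor<1x1x?x?xf32>,
      %acc : tensor<1x1x?x?xf32>) -> tensor<1x1x?x?xf32> {
    %c0 = arith.constant 0 : index
    %c1 = arith.constant 1 : index
    %c2 = arith.constant 2 : index
    %c3 = arith.constant 3 : index
    // Inner tile sizes of the unit-outer-dim operands (lhs = 1x1xM0xK0,
    // rhs = 1x1xN0xK0, acc = 1x1xM0xN0).
    %m0 = tensor.dim %lhs, %c2 : tensor<1x1x?x?xf32>
    %k0 = tensor.dim %lhs, %c3 : tensor<1x1x?x?xf32>
    %n0 = tensor.dim %rhs, %c2 : tensor<1x1x?x?xf32>
    // Rank-reducing extract_slices drop the two unit outer dims, as built by
    // createCanonicalRankReducingExtractSliceOp in the pattern.
    %lhs2d = tensor.extract_slice %lhs[0, 0, 0, 0] [1, 1, %m0, %k0] [1, 1, 1, 1]
        : tensor<1x1x?x?xf32> to tensor<?x?xf32>
    %rhs2d = tensor.extract_slice %rhs[0, 0, 0, 0] [1, 1, %n0, %k0] [1, 1, 1, 1]
        : tensor<1x1x?x?xf32> to tensor<?x?xf32>
    %acc2d = tensor.extract_slice %acc[0, 0, 0, 0] [1, 1, %m0, %n0] [1, 1, 1, 1]
        : tensor<1x1x?x?xf32> to tensor<?x?xf32>
    // m, n, k are taken from the reduced operands; the trailing i32 carries the
    // element-type flags. The constant 1 is a placeholder standing in for
    // IREE_UK_FLAG_MMT4D_TYPE_F32F32F32.
    %m = tensor.dim %lhs2d, %c0 : tensor<?x?xf32>
    %n = tensor.dim %rhs2d, %c0 : tensor<?x?xf32>
    %k = tensor.dim %rhs2d, %c1 : tensor<?x?xf32>
    %flags = arith.constant 1 : i32
    %res2d = iree_codegen.ukernel.generic "accel_matmul_f32"
        ins(%lhs2d, %rhs2d : tensor<?x?xf32>, tensor<?x?xf32>)
        outs(%acc2d : tensor<?x?xf32>)
        (%m, %n, %k, %flags : index, index, index, i32)
        strided_outer_dims(0) -> tensor<?x?xf32>
    // The 2-D ukernel result is re-inserted into the original 4-D destination,
    // mirroring createCanonicalRankReducingInsertSliceOp in the pattern.
    %res = tensor.insert_slice %res2d into %acc[0, 0, 0, 0] [1, 1, %m0, %n0] [1, 1, 1, 1]
        : tensor<?x?xf32> into tensor<1x1x?x?xf32>
    return %res : tensor<1x1x?x?xf32>
  }

Passing the flags word as a trailing i32 operand mirrors the existing CPU ukernel
lowering, so a single "accel_matmul_f32" entry point can dispatch on element types
at runtime instead of needing one symbol per type combination.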