Use the mmt4d path instead of defining a linalg.matmul path.
Fix lit test and pass to use mmt4d op.

(WIP) Use rank-reduced slices of operands in ukernel call.

Fix rank reduction and lit test.
monorimet committed Sep 26, 2023
1 parent 5dd79be commit 4174ffe
Showing 8 changed files with 90 additions and 70 deletions.
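For orientation, here is a rough hand-written sketch (not verbatim compiler output) of the IR shape this commit aims to produce from a linalg.mmt4d with unit outer dimensions: the 1x1 outer dims are peeled off with rank-reducing tensor.extract_slice ops, the 2-D tiles are handed to an iree_codegen.ukernel.generic "accel_matmul_f32" call together with i32 M/N/K sizes and a flags word, and the result is re-inserted into the 4-D accumulator. The function name is illustrative, the flags constant is a placeholder for the real IREE_UK_FLAG_MMT4D_* value, and the ukernel.generic assembly is abbreviated (fn_def_attrs omitted). The path is gated by the existing --iree-llvmcpu-enable-accel-ukernels flag touched in Passes.cpp below.

func.func @mmt4d_lowered_sketch(%lhs : tensor<1x1x?x?xf32>, %rhs : tensor<1x1x?x?xf32>,
                                %acc : tensor<1x1x?x?xf32>) -> tensor<1x1x?x?xf32> {
  %c2 = arith.constant 2 : index
  %c3 = arith.constant 3 : index
  // M/K from the LHS inner dims, N from the RHS inner dims.
  %m = tensor.dim %lhs, %c2 : tensor<1x1x?x?xf32>
  %k = tensor.dim %lhs, %c3 : tensor<1x1x?x?xf32>
  %n = tensor.dim %rhs, %c2 : tensor<1x1x?x?xf32>
  // Rank-reduced slices: drop the two unit outer dims of each operand.
  %lhs2d = tensor.extract_slice %lhs[0, 0, 0, 0] [1, 1, %m, %k] [1, 1, 1, 1]
      : tensor<1x1x?x?xf32> to tensor<?x?xf32>
  %rhs2d = tensor.extract_slice %rhs[0, 0, 0, 0] [1, 1, %n, %k] [1, 1, 1, 1]
      : tensor<1x1x?x?xf32> to tensor<?x?xf32>
  %acc2d = tensor.extract_slice %acc[0, 0, 0, 0] [1, 1, %m, %n] [1, 1, 1, 1]
      : tensor<1x1x?x?xf32> to tensor<?x?xf32>
  %m_i32 = arith.index_cast %m : index to i32
  %n_i32 = arith.index_cast %n : index to i32
  %k_i32 = arith.index_cast %k : index to i32
  // Placeholder for IREE_UK_FLAG_MMT4D_TYPE_F32F32F32 (ORed with the
  // accumulate flag when the init is not a zero fill).
  %flags = arith.constant 1 : i32
  %res2d = iree_codegen.ukernel.generic "accel_matmul_f32"
      ins(%lhs2d, %rhs2d : tensor<?x?xf32>, tensor<?x?xf32>)
      outs(%acc2d : tensor<?x?xf32>)
      (%m_i32, %n_i32, %k_i32, %flags : i32, i32, i32, i32)
      strided_outer_dims(0) -> tensor<?x?xf32>
  // Put the 2-D result back into the original 1x1xMxN accumulator.
  %res = tensor.insert_slice %res2d into %acc[0, 0, 0, 0] [1, 1, %m, %n] [1, 1, 1, 1]
      : tensor<?x?xf32> into tensor<1x1x?x?xf32>
  return %res : tensor<1x1x?x?xf32>
}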
@@ -55,8 +55,6 @@ def SPIRV_WinogradVectorize

def VMVX_Default : I32EnumAttrCase<"VMVXDefault", 24>;

def CPU_AccelMatmulExpert
: I32EnumAttrCase<"AccelMatmulExpert", 25>;

def Linalg_TransformDialectCodegen
: I32EnumAttrCase<"TransformDialectCodegen", 100>;
@@ -78,7 +76,7 @@ def DispatchLoweringPassPipelineEnum
LLVMGPU_PackUnPack, LLVMGPU_MatmulTensorCoreMmaSync,
SPIRV_BaseLowering, SPIRV_BaseDistribute, SPIRV_BaseVectorize,
SPIRV_MatmulPromoteVectorize, SPIRV_CooperativeMatrixVectorize,
SPIRV_SubgroupReduce, SPIRV_WinogradVectorize, VMVX_Default, CPU_AccelMatmulExpert,
SPIRV_SubgroupReduce, SPIRV_WinogradVectorize, VMVX_Default,
// Transform dialect based codegen
Linalg_TransformDialectCodegen, None
]> {
14 changes: 1 addition & 13 deletions compiler/src/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp
@@ -1159,18 +1159,6 @@ static LogicalResult setRootConfig(func::FuncOp entryPointFn,
DispatchLoweringPassPipeline::Mmt4dTilingExpert);
}

/// Sets the lowering configuration for dispatch region for linalg.matmul root
/// op
static LogicalResult setRootConfig(func::FuncOp entryPointFn,
linalg::MatmulOp matmulOp) {
assert(!getLoweringConfig(matmulOp) && "expected lowering_config is not set");
SmallVector<int64_t> tileSizes;
tileSizes.push_back(1);
return setOpConfigAndEntryPointFnTranslation(
entryPointFn, matmulOp, tileSizes,
DispatchLoweringPassPipeline::AccelMatmulExpert);
}

/// Sets the lowering configuration for dispatch region for linalg.batch_mmt4d
/// root op
static LogicalResult setRootConfig(func::FuncOp entryPointFn,
@@ -1993,7 +1981,7 @@ setRootConfigImpl(func::FuncOp entryPointFn, Operation *op,
targetMLTransInfo);
})
.Case<IREE::LinalgExt::FftOp, tensor::PackOp, tensor::PadOp,
linalg::Mmt4DOp, linalg::MatmulOp, linalg::BatchMmt4DOp>(
linalg::Mmt4DOp, linalg::BatchMmt4DOp>(
[&](auto op) { return setRootConfig(entryPointFn, op); })
.Case<linalg::Conv2DNhwcHwcfOp, linalg::Conv2DNchwFchwOp,
linalg::PoolingNhwcSumOp, linalg::PoolingNhwcMaxOp,
@@ -242,7 +242,6 @@ void LLVMCPULowerExecutableTargetPass::runOnOperation() {
(isAArch64(target) && hasAnySVEFeature(target)));

bool enableMicrokernels = hasMicrokernels(target);
bool enableAccelMicrokernels = isX86(target);
bool enableAArch64SSVE = isAArch64(target) && hasAnySVEFeature(target) &&
hasSMEFeature(target);
if (!testLoweringConfiguration) {
@@ -297,13 +296,6 @@ void LLVMCPULowerExecutableTargetPass::runOnOperation() {
tilingConfig, enableMicrokernels);
break;
}
case IREE::Codegen::DispatchLoweringPassPipeline::AccelMatmulExpert: {
TilingConfig tilingConfig = getTilingConfigForPipeline(moduleOp);
addAccelMatmulExpertPassPipeline(executableLoweringPipeline,
tilingConfig,
enableAccelMicrokernels);
break;
}
case IREE::Codegen::DispatchLoweringPassPipeline::CPUDataTiling: {
TilingConfig tilingConfig = getTilingConfigForPipeline(moduleOp);
addCPUDataTilingPipeline(executableLoweringPipeline, tilingConfig,
@@ -82,15 +82,31 @@ getFnNameAndDefAttrs(const char *ukernelName, RewriterBase &rewriter,
return result;
}

/// Matches an (linalg.fill -> )? linalg.matmul operation sequence and converts
/// Matches an (linalg.fill -> )? linalg.mmt4d operation sequence and converts
/// it into an iree_codegen.ukernel.generic "accel_matmul_f32" operation that
/// is later lowered into a call to the microkernel.
static FailureOr<IREE::Codegen::UKernelOpInterface>
matchDAGForUKernel(RewriterBase &rewriter, linalg::MatmulOp op) {
matchDAGForUKernel(RewriterBase &rewriter, linalg::Mmt4DOp op) {
Value lhs = op.getDpsInputOperand(0)->get();
Value rhs = op.getDpsInputOperand(1)->get();
Value out = op.getDpsInitOperand(0)->get();
auto lhsType = llvm::cast<ShapedType>(lhs.getType());
auto rhsType = llvm::cast<ShapedType>(rhs.getType());
auto outType = llvm::cast<ShapedType>(out.getType());
Type lhsElemType = lhsType.getElementType();
Type rhsElemType = rhsType.getElementType();
Type outElemType = outType.getElementType();
uint32_t flags = 0;
if (lhsElemType.isSignlessInteger(8) && rhsElemType.isSignlessInteger(8) &&
outElemType.isSignlessInteger(32)) {
flags = IREE_UK_FLAG_MMT4D_TYPE_I8I8I32;
} else if (lhsElemType.isF32() && rhsElemType.isF32() &&
outElemType.isF32()) {
flags = IREE_UK_FLAG_MMT4D_TYPE_F32F32F32;
} else {
return rewriter.notifyMatchFailure(
op, "unsupported combination of element types");
}

// Check if the accumulator is zero-filled.
if (isInitializedToZero(out)) {
@@ -99,18 +115,63 @@ matchDAGForUKernel(RewriterBase &rewriter, linalg::MatmulOp op) {
if (auto fillOp = out.getDefiningOp<linalg::FillOp>()) {
out = fillOp.getDpsInitOperand(0)->get();
}
} else {
// Tell the mmt4d op to read the existing accumulator.
flags |= IREE_UK_FLAG_MMT4D_ACCUMULATE;
}

Location loc = op.getLoc();
Value m = rewriter.create<tensor::DimOp>(loc, lhs, 0);
Value n = rewriter.create<tensor::DimOp>(loc, rhs, 0);
Value k = rewriter.create<tensor::DimOp>(loc, rhs, 1);

if (outType.getShape()[0] != 1 || outType.getShape()[1] != 1) {
return rewriter.notifyMatchFailure(op, "outer dims need to be 1");
}

auto outTypeRanked = out.getType().cast<RankedTensorType>();
RankedTensorType intermediateOutType =
RankedTensorType::Builder(outTypeRanked).dropDim(0);
RankedTensorType reducedOutType =
RankedTensorType::Builder(intermediateOutType).dropDim(0);
Value reducedOut = tensor::createCanonicalRankReducingExtractSliceOp(
rewriter, loc, out, reducedOutType);

auto lhsTypeRanked = lhs.getType().cast<RankedTensorType>();
RankedTensorType intermediateLhsType =
RankedTensorType::Builder(lhsTypeRanked).dropDim(0);
RankedTensorType reducedLhsType =
RankedTensorType::Builder(intermediateLhsType).dropDim(0);
auto reducedLhs = tensor::createCanonicalRankReducingExtractSliceOp(
rewriter, loc, lhs, reducedLhsType);

auto rhsTypeRanked = rhs.getType().cast<RankedTensorType>();
RankedTensorType intermediateRhsType =
RankedTensorType::Builder(rhsTypeRanked).dropDim(0);
RankedTensorType reducedRhsType =
RankedTensorType::Builder(intermediateRhsType).dropDim(0);
auto reducedRhs = tensor::createCanonicalRankReducingExtractSliceOp(
rewriter, loc, rhs, reducedRhsType);

auto getDimAsI32 = [](RewriterBase &rewriter, Location loc, Value value,
int dim) -> Value {
return rewriter.create<arith::IndexCastOp>(
loc, rewriter.getI32Type(),
rewriter.create<tensor::DimOp>(loc, value, dim));
};
Value m = getDimAsI32(rewriter, loc, reducedLhs, 0);
Value n = getDimAsI32(rewriter, loc, reducedRhs, 0);
Value k = getDimAsI32(rewriter, loc, reducedRhs, 1);

Value flagsVal = rewriter.create<arith::ConstantOp>(
loc, rewriter.getI32IntegerAttr(flags));
auto targetAttr = IREE::HAL::ExecutableTargetAttr::lookup(op);
auto fn = getFnNameAndDefAttrs("accel_matmul_f32", rewriter, targetAttr);
auto genericMicroKernelOp = rewriter.create<IREE::Codegen::UKernelGenericOp>(
loc, outType, fn.name, ValueRange{lhs, rhs}, out, ValueRange{m, n, k},
loc, reducedOutType, fn.name, ValueRange{reducedLhs, reducedRhs}, reducedOut,
ValueRange{m, n, k, flagsVal},
/*fn_def_attrs=*/rewriter.getDictionaryAttr(fn.defAttrs),
/*strided_outer_dims=*/rewriter.getIndexAttr(0));
auto insertSliceOp = tensor::createCanonicalRankReducingInsertSliceOp(
rewriter, loc, genericMicroKernelOp.getResult(0), out);
op.getResults()[0].replaceAllUsesWith(insertSliceOp);
return cast<IREE::Codegen::UKernelOpInterface>(
genericMicroKernelOp.getOperation());
}
@@ -143,7 +204,7 @@ void LLVMCPULowerToAccelUKernelsPass::runOnOperation() {
// Since microkernels are linked as bitcode, they will still undergo LTO-like
// optimization in their calling contexts, but we shouldn't expect this to
// achieve similar results as fusing structured ops.
patterns.insert<LowerToAccelUKernelPattern<linalg::MatmulOp>>(context);
patterns.insert<LowerToAccelUKernelPattern<linalg::Mmt4DOp>>(context);
if (failed(
applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) {
return signalPassFailure();
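As a small illustration of the accumulator handling in the pattern above (a hand-written sketch, not taken from the test suite): when the mmt4d init operand is produced by a zero linalg.fill, the fill is looked through and the accumulate flag stays clear; any other init keeps its contents and sets IREE_UK_FLAG_MMT4D_ACCUMULATE.

func.func @zero_filled_acc(%lhs : tensor<1x1x?x?xf32>, %rhs : tensor<1x1x?x?xf32>,
                           %acc : tensor<1x1x?x?xf32>) -> tensor<1x1x?x?xf32> {
  %zero = arith.constant 0.0 : f32
  // isInitializedToZero(out) matches this fill, so the rewrite uses %acc
  // directly and leaves IREE_UK_FLAG_MMT4D_ACCUMULATE unset.
  %filled = linalg.fill ins(%zero : f32) outs(%acc : tensor<1x1x?x?xf32>) -> tensor<1x1x?x?xf32>
  %0 = linalg.mmt4d ins(%lhs, %rhs : tensor<1x1x?x?xf32>, tensor<1x1x?x?xf32>)
                    outs(%filled : tensor<1x1x?x?xf32>) -> tensor<1x1x?x?xf32>
  return %0 : tensor<1x1x?x?xf32>
}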
42 changes: 13 additions & 29 deletions compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp
@@ -57,7 +57,7 @@ static llvm::cl::opt<bool> clEnablePadConsumerFusion(

static llvm::cl::opt<bool> clEnableAccelMicrokernels(
"iree-llvmcpu-enable-accel-ukernels",
llvm::cl::desc("Flag to enable lowering to accelUkernels"),
llvm::cl::desc("Flag to enable lowering to accel microkernels"),
llvm::cl::init(false));

static llvm::cl::opt<bool> clEnableMicrokernelsDecomposeLinalgGeneric(
@@ -611,10 +611,20 @@ void addMmt4dTilingExpertPassPipeline(OpPassManager &passManager,

OpPassManager &nestedModulePM = passManager.nest<ModuleOp>();

if (enableMicrokernels) {
if (clEnableAccelMicrokernels) {
nestedModulePM.addNestedPass<func::FuncOp>(createLLVMCPUTileAndFusePass(
static_cast<int64_t>(tilingConfig.getVectorCommonParallelLevel())));
nestedModulePM.addNestedPass<func::FuncOp>(createLLVMCPUTilePass(
static_cast<int64_t>(tilingConfig.getVectorReductionLevel())));
nestedModulePM.addNestedPass<func::FuncOp>(
createDecomposeBatchMmt4DOpsPass());
nestedModulePM.addPass(
createLLVMCPULowerToAccelUKernelsPass());
nestedModulePM.addNestedPass<func::FuncOp>(
createConvertToDestinationPassingStylePass());
} else if (enableMicrokernels) {
nestedModulePM.addNestedPass<func::FuncOp>(
createDecomposeBatchMmt4DOpsPass());
nestedModulePM.addPass(createLLVMCPULowerToAccelUKernelsPass());
nestedModulePM.addPass(
createLLVMCPULowerToUKernelsPass(clSkipIntermediateRoundings));
} else {
@@ -639,32 +649,6 @@ void addMmt4dTilingExpertPassPipeline(OpPassManager &passManager,
}
}

void addAccelMatmulExpertPassPipeline(OpPassManager &passManager,
TilingConfig &tilingConfig,
bool enableAccelMicrokernels) {
addTileAndDistributePasses(passManager);

OpPassManager &nestedModulePM = passManager.nest<ModuleOp>();

if (enableAccelMicrokernels) {
nestedModulePM.addPass(createLLVMCPULowerToAccelUKernelsPass());
} else {
nestedModulePM.addNestedPass<func::FuncOp>(createLLVMCPUTileAndFusePass(
static_cast<int64_t>(tilingConfig.getVectorCommonParallelLevel())));
nestedModulePM.addNestedPass<func::FuncOp>(createLLVMCPUTilePass(
static_cast<int64_t>(tilingConfig.getVectorReductionLevel())));
nestedModulePM.addNestedPass<func::FuncOp>(
createGenericVectorizationPass());
nestedModulePM.addNestedPass<func::FuncOp>(
createHoistRedundantVectorTransfersPass());
}

nestedModulePM.addNestedPass<func::FuncOp>(createCanonicalizerPass());
nestedModulePM.addNestedPass<func::FuncOp>(createCSEPass());

addBufferizePasses(nestedModulePM);
}

void addCPUDataTilingPipeline(OpPassManager &passManager,
TilingConfig &tilingConfig,
bool enableVectorMasking) {
4 changes: 0 additions & 4 deletions compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.h
@@ -149,10 +149,6 @@ void addMmt4dTilingExpertPassPipeline(OpPassManager &passManager,
TilingConfig &tilingConfig,
bool enableMicrokernels);

void addAccelMatmulExpertPassPipeline(OpPassManager &passManager,
TilingConfig &tilingConfig,
bool enableAccelMicrokernels);

void addMultiTilingExpertPassPipeline(
OpPassManager &passManager, TilingConfig &tilingConfig, bool enablePeeling,
bool enableVectorMasking, bool lowerToAVX2, bool enableAArch64SSVE = false);
@@ -1,10 +1,12 @@
// RUN: iree-opt --pass-pipeline="builtin.module(func.func(iree-llvmcpu-lower-to-accel-ukernels,cse,canonicalize))" %s | FileCheck %s

func.func @matmul_f32f32f32(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>, %arg2 : tensor<?x?xf32>) -> tensor<?x?xf32> {
%0 = linalg.matmul ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
outs(%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32>
return %0 : tensor<?x?xf32>
func.func @mmt4d_f32f32f32(%arg0 : tensor<1x1x?x?xf32>, %arg1 : tensor<1x1x?x?xf32>,
%arg2 : tensor<1x1x?x?xf32>) -> tensor<1x1x?x?xf32> {
%0 = linalg.mmt4d ins(%arg0, %arg1 : tensor<1x1x?x?xf32>, tensor<1x1x?x?xf32>)
outs(%arg2 : tensor<1x1x?x?xf32>) -> tensor<1x1x?x?xf32>
return %0 : tensor<1x1x?x?xf32>
}

// CHECK: func @matmul_f32f32f32(
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<?x?xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<?x?xf32>
@@ -14,7 +16,7 @@ func.func @matmul_f32f32f32(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>, %a
// CHECK-DAG: %[[M:.+]] = tensor.dim %[[ARG0]], %[[C0]]
// CHECK-DAG: %[[N:.+]] = tensor.dim %[[ARG1]], %[[C0]]
// CHECK-DAG: %[[K:.+]] = tensor.dim %[[ARG1]], %[[C1]]
// CHECK: %[[MICRO_KERNEL:.+]] = iree_codegen.ukernel.generic "aie_matmul_f32"
// CHECK: %[[MICRO_KERNEL:.+]] = iree_codegen.ukernel.generic "accel_matmul_f32"
// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] :
// CHECK-SAME: outs(%[[ARG2]] :
// CHECK-SAME: (%[[M]], %[[N]], %[[K]] :
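The element-type switch in the pattern also accepts signless i8 inputs with an i32 accumulator; a hypothetical companion test input for that path (not part of this change set) would look like:

func.func @mmt4d_i8i8i32(%arg0 : tensor<1x1x?x?xi8>, %arg1 : tensor<1x1x?x?xi8>,
                         %arg2 : tensor<1x1x?x?xi32>) -> tensor<1x1x?x?xi32> {
  // Selects IREE_UK_FLAG_MMT4D_TYPE_I8I8I32 in matchDAGForUKernel.
  %0 = linalg.mmt4d ins(%arg0, %arg1 : tensor<1x1x?x?xi8>, tensor<1x1x?x?xi8>)
                    outs(%arg2 : tensor<1x1x?x?xi32>) -> tensor<1x1x?x?xi32>
  return %0 : tensor<1x1x?x?xi32>
}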
1 change: 0 additions & 1 deletion samples/custom_dispatch/cpu/plugin/CMakeLists.txt
@@ -21,7 +21,6 @@ target_include_directories(iree_samples_custom_dispatch_cpu_system_plugin
${IREE_SOURCE_DIR}/runtime/src/
)

iree_add_all_subdirs()
# NOTE: this is only required because we want this sample to run on all
# platforms without needing to change the library name (libfoo.so/foo.dll).
set_target_properties(iree_samples_custom_dispatch_cpu_system_plugin
