Commit 58fa6dc

Use mmt4d path instead of defining linalg.matmul path.
Fix lit test and pass to use mmt4d op.

(WIP) Use rank-reduced slices of operands in ukernel call.

Fix rank reduction and lit test.

Remove isInitializedToZero from accel codegen.

Fix dims.
monorimet committed Sep 26, 2023
1 parent 229aede commit 58fa6dc
Showing 8 changed files with 89 additions and 88 deletions.
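
For orientation: the accel microkernel path now matches linalg.mmt4d ops whose two
outer tile dimensions are 1, instead of defining its own linalg.matmul path. A
minimal sketch of the IR this path targets, adapted from the updated lit test
further down (shapes and SSA names are illustrative):

  func.func @mmt4d_f32f32f32(%lhs : tensor<1x1x?x?xf32>, %rhs : tensor<1x1x?x?xf32>,
                             %acc : tensor<1x1x?x?xf32>) -> tensor<1x1x?x?xf32> {
    // linalg.mmt4d multiplies pre-packed 4-D tiles: acc += lhs * transpose(rhs).
    %0 = linalg.mmt4d ins(%lhs, %rhs : tensor<1x1x?x?xf32>, tensor<1x1x?x?xf32>)
                      outs(%acc : tensor<1x1x?x?xf32>) -> tensor<1x1x?x?xf32>
    return %0 : tensor<1x1x?x?xf32>
  }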
@@ -55,8 +55,6 @@ def SPIRV_WinogradVectorize

def VMVX_Default : I32EnumAttrCase<"VMVXDefault", 24>;

def CPU_AccelMatmulExpert
: I32EnumAttrCase<"AccelMatmulExpert", 25>;

def Linalg_TransformDialectCodegen
: I32EnumAttrCase<"TransformDialectCodegen", 100>;
@@ -78,7 +76,7 @@ def DispatchLoweringPassPipelineEnum
LLVMGPU_PackUnPack, LLVMGPU_MatmulTensorCoreMmaSync,
SPIRV_BaseLowering, SPIRV_BaseDistribute, SPIRV_BaseVectorize,
SPIRV_MatmulPromoteVectorize, SPIRV_CooperativeMatrixVectorize,
SPIRV_SubgroupReduce, SPIRV_WinogradVectorize, VMVX_Default, CPU_AccelMatmulExpert,
SPIRV_SubgroupReduce, SPIRV_WinogradVectorize, VMVX_Default,
// Transform dialect based codegen
Linalg_TransformDialectCodegen, None
]> {
14 changes: 1 addition & 13 deletions compiler/src/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp
@@ -1159,18 +1159,6 @@ static LogicalResult setRootConfig(func::FuncOp entryPointFn,
DispatchLoweringPassPipeline::Mmt4dTilingExpert);
}

/// Sets the lowering configuration for dispatch region for linalg.matmul root
/// op
static LogicalResult setRootConfig(func::FuncOp entryPointFn,
linalg::MatmulOp matmulOp) {
assert(!getLoweringConfig(matmulOp) && "expected lowering_config is not set");
SmallVector<int64_t> tileSizes;
tileSizes.push_back(1);
return setOpConfigAndEntryPointFnTranslation(
entryPointFn, matmulOp, tileSizes,
DispatchLoweringPassPipeline::AccelMatmulExpert);
}

/// Sets the lowering configuration for dispatch region for linalg.batch_mmt4d
/// root op
static LogicalResult setRootConfig(func::FuncOp entryPointFn,
@@ -1993,7 +1981,7 @@ setRootConfigImpl(func::FuncOp entryPointFn, Operation *op,
targetMLTransInfo);
})
.Case<IREE::LinalgExt::FftOp, tensor::PackOp, tensor::PadOp,
linalg::Mmt4DOp, linalg::MatmulOp, linalg::BatchMmt4DOp>(
linalg::Mmt4DOp, linalg::BatchMmt4DOp>(
[&](auto op) { return setRootConfig(entryPointFn, op); })
.Case<linalg::Conv2DNhwcHwcfOp, linalg::Conv2DNchwFchwOp,
linalg::PoolingNhwcSumOp, linalg::PoolingNhwcMaxOp,
@@ -242,7 +242,6 @@ void LLVMCPULowerExecutableTargetPass::runOnOperation() {
(isAArch64(target) && hasAnySVEFeature(target)));

bool enableMicrokernels = hasMicrokernels(target);
bool enableAccelMicrokernels = isX86(target);
bool enableAArch64SSVE = isAArch64(target) && hasAnySVEFeature(target) &&
hasSMEFeature(target);
if (!testLoweringConfiguration) {
@@ -297,13 +296,6 @@ void LLVMCPULowerExecutableTargetPass::runOnOperation() {
tilingConfig, enableMicrokernels);
break;
}
case IREE::Codegen::DispatchLoweringPassPipeline::AccelMatmulExpert: {
TilingConfig tilingConfig = getTilingConfigForPipeline(moduleOp);
addAccelMatmulExpertPassPipeline(executableLoweringPipeline,
tilingConfig,
enableAccelMicrokernels);
break;
}
case IREE::Codegen::DispatchLoweringPassPipeline::CPUDataTiling: {
TilingConfig tilingConfig = getTilingConfigForPipeline(moduleOp);
addCPUDataTilingPipeline(executableLoweringPipeline, tilingConfig,
@@ -47,16 +47,6 @@ class LLVMCPULowerToAccelUKernelsPass
}
};

/// Returns `true` if an `outsOperand` value is initialized to zero.
static bool isInitializedToZero(Value outsOperand) {
auto fillOp = outsOperand.getDefiningOp<linalg::FillOp>();
if (!fillOp)
return false;
Value fillVal = fillOp.getDpsInputOperand(0)->get();
return matchPattern(fillVal, m_Zero()) ||
matchPattern(fillVal, m_AnyZeroFloat());
}
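
For context, the deleted helper recognized accumulators initialized by a zero
linalg.fill, as in the following illustrative sketch (made-up names and static
shapes):

  %zero = arith.constant 0.0 : f32
  %acc = linalg.fill ins(%zero : f32)
             outs(%init : tensor<1x1x8x8xf32>) -> tensor<1x1x8x8xf32>

The check is no longer needed: as the rewritten matcher below notes, the accel
plugin does not read the incoming accumulator, so any linalg.fill producer is
discarded regardless of its fill value.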

/// Holds a function name and attributes.
struct FnNameAndDefAttrs {
std::string name;
@@ -82,35 +72,87 @@ getFnNameAndDefAttrs(const char *ukernelName, RewriterBase &rewriter,
return result;
}

/// Matches an (linalg.fill -> )? linalg.matmul operation sequence and converts
/// Matches an (linalg.fill -> )? linalg.mmt4d operation sequence and converts
/// it into an iree_codegen.ukernel.generic "accel_matmul_f32" operation that
/// is later lowered into a call to the microkernel.
static FailureOr<IREE::Codegen::UKernelOpInterface>
matchDAGForUKernel(RewriterBase &rewriter, linalg::MatmulOp op) {
matchDAGForUKernel(RewriterBase &rewriter, linalg::Mmt4DOp op) {
Value lhs = op.getDpsInputOperand(0)->get();
Value rhs = op.getDpsInputOperand(1)->get();
Value out = op.getDpsInitOperand(0)->get();
auto lhsType = llvm::cast<ShapedType>(lhs.getType());
auto rhsType = llvm::cast<ShapedType>(rhs.getType());
auto outType = llvm::cast<ShapedType>(out.getType());

// Check if the accumulator is zero-filled.
if (isInitializedToZero(out)) {
// The plugin will not read the existing accumulator, so its defining op can
// be discarded.
if (auto fillOp = out.getDefiningOp<linalg::FillOp>()) {
out = fillOp.getDpsInitOperand(0)->get();
}
Type lhsElemType = lhsType.getElementType();
Type rhsElemType = rhsType.getElementType();
Type outElemType = outType.getElementType();
uint32_t flags = 0;
if (lhsElemType.isSignlessInteger(8) && rhsElemType.isSignlessInteger(8) &&
outElemType.isSignlessInteger(32)) {
flags = IREE_UK_FLAG_MMT4D_TYPE_I8I8I32;
} else if (lhsElemType.isF32() && rhsElemType.isF32() &&
outElemType.isF32()) {
flags = IREE_UK_FLAG_MMT4D_TYPE_F32F32F32;
} else {
return rewriter.notifyMatchFailure(
op, "unsupported combination of element types");
}

Location loc = op.getLoc();
Value m = rewriter.create<tensor::DimOp>(loc, lhs, 0);
Value n = rewriter.create<tensor::DimOp>(loc, rhs, 0);
Value k = rewriter.create<tensor::DimOp>(loc, rhs, 1);

if (outType.getShape()[0] != 1 || outType.getShape()[1] != 1) {
return rewriter.notifyMatchFailure(op, "outer dims need to be 1");
}

auto outTypeRanked = out.getType().cast<RankedTensorType>();
RankedTensorType intermediateOutType =
RankedTensorType::Builder(outTypeRanked).dropDim(0);
RankedTensorType reducedOutType =
RankedTensorType::Builder(intermediateOutType).dropDim(0);
Value reducedOut = tensor::createCanonicalRankReducingExtractSliceOp(
rewriter, loc, out, reducedOutType);

auto lhsTypeRanked = lhs.getType().cast<RankedTensorType>();
RankedTensorType intermediateLhsType =
RankedTensorType::Builder(lhsTypeRanked).dropDim(0);
RankedTensorType reducedLhsType =
RankedTensorType::Builder(intermediateLhsType).dropDim(0);
auto reducedLhs = tensor::createCanonicalRankReducingExtractSliceOp(
rewriter, loc, lhs, reducedLhsType);

auto rhsTypeRanked = rhs.getType().cast<RankedTensorType>();
RankedTensorType intermediateRhsType =
RankedTensorType::Builder(rhsTypeRanked).dropDim(0);
RankedTensorType reducedRhsType =
RankedTensorType::Builder(intermediateRhsType).dropDim(0);
auto reducedRhs = tensor::createCanonicalRankReducingExtractSliceOp(
rewriter, loc, rhs, reducedRhsType);
/*
auto getDimAsI32 = [](RewriterBase &rewriter, Location loc, Value value,
int dim) -> Value {
return rewriter.create<arith::IndexCastOp>(
loc, rewriter.getI32Type(),
rewriter.create<tensor::DimOp>(loc, value, dim));
};
Value m = getDimAsI32(rewriter, loc, reducedLhs, 0);
Value n = getDimAsI32(rewriter, loc, reducedRhs, 0);
Value k = getDimAsI32(rewriter, loc, reducedRhs, 1);
*/
Value m = rewriter.create<tensor::DimOp>(loc, reducedLhs, 0);
Value n = rewriter.create<tensor::DimOp>(loc, reducedRhs, 0);
Value k = rewriter.create<tensor::DimOp>(loc, reducedRhs, 1);
Value flagsVal = rewriter.create<arith::ConstantOp>(
loc, rewriter.getI32IntegerAttr(flags));
auto targetAttr = IREE::HAL::ExecutableTargetAttr::lookup(op);
auto fn = getFnNameAndDefAttrs("accel_matmul_f32", rewriter, targetAttr);
auto genericMicroKernelOp = rewriter.create<IREE::Codegen::UKernelGenericOp>(
loc, outType, fn.name, ValueRange{lhs, rhs}, out, ValueRange{m, n, k},
loc, reducedOutType, fn.name, ValueRange{reducedLhs, reducedRhs}, reducedOut,
ValueRange{m, n, k, flagsVal},
/*fn_def_attrs=*/rewriter.getDictionaryAttr(fn.defAttrs),
/*strided_outer_dims=*/rewriter.getIndexAttr(0));
auto insertSliceOp = tensor::createCanonicalRankReducingInsertSliceOp(
rewriter, loc, genericMicroKernelOp.getResult(0), out);
op.getResults()[0].replaceAllUsesWith(insertSliceOp);
return cast<IREE::Codegen::UKernelOpInterface>(
genericMicroKernelOp.getOperation());
}
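
The net effect of the rewrite above, as an illustrative sketch (SSA names are
made up, and the ukernel.generic syntax is abbreviated; the real op also carries
fn_def_attrs and strided_outer_dims): the unit outer dims are stripped with
rank-reducing extract_slices, the microkernel is invoked on plain 2-D tensors,
and the result is re-inserted into the original 4-D destination.

  %c2 = arith.constant 2 : index
  %c3 = arith.constant 3 : index
  // Dynamic inner tile sizes of the packed operands.
  %m = tensor.dim %lhs, %c2 : tensor<1x1x?x?xf32>
  %k = tensor.dim %lhs, %c3 : tensor<1x1x?x?xf32>
  %n = tensor.dim %rhs, %c2 : tensor<1x1x?x?xf32>
  // Rank-reducing slices drop the two unit outer dims.
  %lhs2d = tensor.extract_slice %lhs[0, 0, 0, 0] [1, 1, %m, %k] [1, 1, 1, 1]
      : tensor<1x1x?x?xf32> to tensor<?x?xf32>
  %rhs2d = tensor.extract_slice %rhs[0, 0, 0, 0] [1, 1, %n, %k] [1, 1, 1, 1]
      : tensor<1x1x?x?xf32> to tensor<?x?xf32>
  %out2d = tensor.extract_slice %out[0, 0, 0, 0] [1, 1, %m, %n] [1, 1, 1, 1]
      : tensor<1x1x?x?xf32> to tensor<?x?xf32>
  // Type-selection flag (value illustrative), e.g. IREE_UK_FLAG_MMT4D_TYPE_F32F32F32.
  %flags = arith.constant 2 : i32
  %res2d = iree_codegen.ukernel.generic "accel_matmul_f32"
      ins(%lhs2d, %rhs2d : tensor<?x?xf32>, tensor<?x?xf32>)
      outs(%out2d : tensor<?x?xf32>)
      (%m, %n, %k, %flags : index, index, index, i32) -> tensor<?x?xf32>
  %res = tensor.insert_slice %res2d into %out[0, 0, 0, 0] [1, 1, %m, %n] [1, 1, 1, 1]
      : tensor<?x?xf32> into tensor<1x1x?x?xf32>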
@@ -143,7 +185,7 @@ void LLVMCPULowerToAccelUKernelsPass::runOnOperation() {
// Since microkernels are linked as bitcode, they will still undergo LTO-like
// optimization in their calling contexts, but we shouldn't expect this to
// achieve similar results as fusing structured ops.
patterns.insert<LowerToAccelUKernelPattern<linalg::MatmulOp>>(context);
patterns.insert<LowerToAccelUKernelPattern<linalg::Mmt4DOp>>(context);
if (failed(
applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) {
return signalPassFailure();
42 changes: 13 additions & 29 deletions compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp
@@ -57,7 +57,7 @@ static llvm::cl::opt<bool> clEnablePadConsumerFusion(

static llvm::cl::opt<bool> clEnableAccelMicrokernels(
"iree-llvmcpu-enable-accel-ukernels",
llvm::cl::desc("Flag to enable lowering to accelUkernels"),
llvm::cl::desc("Flag to enable lowering to accel microkernels"),
llvm::cl::init(false));

static llvm::cl::opt<bool> clEnableMicrokernelsDecomposeLinalgGeneric(
@@ -611,10 +611,20 @@ void addMmt4dTilingExpertPassPipeline(OpPassManager &passManager,

OpPassManager &nestedModulePM = passManager.nest<ModuleOp>();

if (enableMicrokernels) {
if (clEnableAccelMicrokernels) {
nestedModulePM.addNestedPass<func::FuncOp>(createLLVMCPUTileAndFusePass(
static_cast<int64_t>(tilingConfig.getVectorCommonParallelLevel())));
nestedModulePM.addNestedPass<func::FuncOp>(createLLVMCPUTilePass(
static_cast<int64_t>(tilingConfig.getVectorReductionLevel())));
nestedModulePM.addNestedPass<func::FuncOp>(
createDecomposeBatchMmt4DOpsPass());
nestedModulePM.addPass(
createLLVMCPULowerToAccelUKernelsPass());
nestedModulePM.addNestedPass<func::FuncOp>(
createConvertToDestinationPassingStylePass());
} else if (enableMicrokernels) {
nestedModulePM.addNestedPass<func::FuncOp>(
createDecomposeBatchMmt4DOpsPass());
nestedModulePM.addPass(createLLVMCPULowerToAccelUKernelsPass());
nestedModulePM.addPass(
createLLVMCPULowerToUKernelsPass(clSkipIntermediateRoundings));
} else {
@@ -639,32 +649,6 @@
}
}

void addAccelMatmulExpertPassPipeline(OpPassManager &passManager,
TilingConfig &tilingConfig,
bool enableAccelMicrokernels) {
addTileAndDistributePasses(passManager);

OpPassManager &nestedModulePM = passManager.nest<ModuleOp>();

if (enableAccelMicrokernels) {
nestedModulePM.addPass(createLLVMCPULowerToAccelUKernelsPass());
} else {
nestedModulePM.addNestedPass<func::FuncOp>(createLLVMCPUTileAndFusePass(
static_cast<int64_t>(tilingConfig.getVectorCommonParallelLevel())));
nestedModulePM.addNestedPass<func::FuncOp>(createLLVMCPUTilePass(
static_cast<int64_t>(tilingConfig.getVectorReductionLevel())));
nestedModulePM.addNestedPass<func::FuncOp>(
createGenericVectorizationPass());
nestedModulePM.addNestedPass<func::FuncOp>(
createHoistRedundantVectorTransfersPass());
}

nestedModulePM.addNestedPass<func::FuncOp>(createCanonicalizerPass());
nestedModulePM.addNestedPass<func::FuncOp>(createCSEPass());

addBufferizePasses(nestedModulePM);
}

void addCPUDataTilingPipeline(OpPassManager &passManager,
TilingConfig &tilingConfig,
bool enableVectorMasking) {
4 changes: 0 additions & 4 deletions compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.h
@@ -149,10 +149,6 @@ void addMmt4dTilingExpertPassPipeline(OpPassManager &passManager,
TilingConfig &tilingConfig,
bool enableMicrokernels);

void addAccelMatmulExpertPassPipeline(OpPassManager &passManager,
TilingConfig &tilingConfig,
bool enableAccelMicrokernels);

void addMultiTilingExpertPassPipeline(
OpPassManager &passManager, TilingConfig &tilingConfig, bool enablePeeling,
bool enableVectorMasking, bool lowerToAVX2, bool enableAArch64SSVE = false);
@@ -1,10 +1,12 @@
// RUN: iree-opt --pass-pipeline="builtin.module(func.func(iree-llvmcpu-lower-to-accel-ukernels,cse,canonicalize))" %s | FileCheck %s

func.func @matmul_f32f32f32(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>, %arg2 : tensor<?x?xf32>) -> tensor<?x?xf32> {
%0 = linalg.matmul ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
outs(%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32>
return %0 : tensor<?x?xf32>
func.func @mmt4d_f32f32f32(%arg0 : tensor<1x1x?x?xf32>, %arg1 : tensor<1x1x?x?xf32>,
%arg2 : tensor<1x1x?x?xf32>) -> tensor<1x1x?x?xf32> {
%0 = linalg.mmt4d ins(%arg0, %arg1 : tensor<1x1x?x?xf32>, tensor<1x1x?x?xf32>)
outs(%arg2 : tensor<1x1x?x?xf32>) -> tensor<1x1x?x?xf32>
return %0 : tensor<1x1x?x?xf32>
}

// CHECK: func @matmul_f32f32f32(
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<?x?xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<?x?xf32>
@@ -14,7 +16,7 @@ func.func @matmul_f32f32f32(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>, %a
// CHECK-DAG: %[[M:.+]] = tensor.dim %[[ARG0]], %[[C0]]
// CHECK-DAG: %[[N:.+]] = tensor.dim %[[ARG1]], %[[C0]]
// CHECK-DAG: %[[K:.+]] = tensor.dim %[[ARG1]], %[[C1]]
// CHECK: %[[MICRO_KERNEL:.+]] = iree_codegen.ukernel.generic "aie_matmul_f32"
// CHECK: %[[MICRO_KERNEL:.+]] = iree_codegen.ukernel.generic "accel_matmul_f32"
// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] :
// CHECK-SAME: outs(%[[ARG2]] :
// CHECK-SAME: (%[[M]], %[[N]], %[[K]] :
1 change: 0 additions & 1 deletion samples/custom_dispatch/cpu/plugin/CMakeLists.txt
@@ -21,7 +21,6 @@ target_include_directories(iree_samples_custom_dispatch_cpu_system_plugin
${IREE_SOURCE_DIR}/runtime/src/
)

iree_add_all_subdirs()
# NOTE: this is only required because we want this sample to run on all
# platforms without needing to change the library name (libfoo.so/foo.dll).
set_target_properties(iree_samples_custom_dispatch_cpu_system_plugin