Skip to content

Commit

Permalink
[LLVMCPU] Add additional level of tiling to the default
Browse files Browse the repository at this point in the history
  • Loading branch information
pashu123 committed Nov 10, 2024
1 parent 5c45591 commit 864d9df
Show file tree
Hide file tree
Showing 5 changed files with 21 additions and 8 deletions.
6 changes: 6 additions & 0 deletions compiler/src/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2508,6 +2508,12 @@ static LogicalResult setRootConfig(mlir::FunctionOpInterface entryPointFn,
SmallVector<int64_t> distTileSizes =
getDefaultDistributedLevelTileSizes(op, DistributionHeuristicConfig{});
TileSizesListType tileSizes = {distTileSizes};
if(auto linalgOp = dyn_cast<linalg::LinalgOp>(*op)){
SmallVector<int64_t> vecTileSizes = distTileSizes;
limitVectorTileSizes(linalgOp, vecTileSizes);
tileSizes.push_back(vecTileSizes);
}

return setOpConfigAndEntryPointFnTranslation(
entryPointFn, op, tileSizes, DispatchLoweringPassPipeline::CPUDefault);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ getRootLoweringConfig(FunctionOpInterface funcOp) {

static TilingConfig getTilingConfigForPipeline(FunctionOpInterface funcOp) {
auto maybeLoweringConfig = getRootLoweringConfig(funcOp);
llvm::errs()<<"Hey I am here";
assert(succeeded(maybeLoweringConfig) &&
"Pipeline requires a lowering config");
return TilingConfig(*maybeLoweringConfig);
Expand Down Expand Up @@ -122,9 +123,11 @@ void LLVMCPULowerExecutableTargetPass::runOnOperation() {
// No pipleline specified, nothing to do.
case IREE::Codegen::DispatchLoweringPassPipeline::None:
return;
case IREE::Codegen::DispatchLoweringPassPipeline::CPUDefault:
addCPUDefaultPassPipeline(pipeline);
case IREE::Codegen::DispatchLoweringPassPipeline::CPUDefault: {
TilingConfig tilingConfig = getTilingConfigForPipeline(funcOp);
addCPUDefaultPassPipeline(pipeline, tilingConfig);
break;
}
case IREE::Codegen::DispatchLoweringPassPipeline::
CPUBufferOpsTileAndVectorize: {
TilingConfig tilingConfig = getTilingConfigForPipeline(funcOp);
Expand Down
6 changes: 5 additions & 1 deletion compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -653,8 +653,12 @@ void addCPULinalgExtTileAndVectorizePipeline(
}
}

void addCPUDefaultPassPipeline(OpPassManager &funcPassManager) {
void addCPUDefaultPassPipeline(OpPassManager &funcPassManager, TilingConfig &tilingConfig) {
addTileAndDistributePasses(funcPassManager);
if(tilingConfig.getNumTilingLevels() > 1){
funcPassManager.addPass(createLLVMCPUTileAndFusePass(
tilingConfig.getVectorCommonParallelLevel()));
}
addCPUBufferizePasses(funcPassManager);
}

Expand Down
2 changes: 1 addition & 1 deletion compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ void addCPULinalgExtTileAndVectorizePipeline(
/// Populates the passes to lower to scalars operations for linalg based
/// code-generation. This pipeline does not vectorize, but instead just
/// converts to memrefs
void addCPUDefaultPassPipeline(OpPassManager &funcPassManager);
void addCPUDefaultPassPipeline(OpPassManager &funcPassManager, TilingConfig &tilingConfig);

void addConvTileAndDecomposeExpertPassPipeline(
OpPassManager &funcPassManager, TilingConfig &tilingConfig,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -550,10 +550,10 @@ isFusableWithConsumer(OpOperand &fusedOperand,
// TODO: Enable grouped convolution and depth wise pooling fusion.
// Rightnow, this is going through the default CPU pipeline and not through
// CONVTilingExpert.
if (isa<linalg::Conv2DNgchwFgchwOp, linalg::Conv2DNgchwGfchwOp,
linalg::PoolingNdhwcSumOp>(producer)) {
return false;
}
// if (isa<linalg::Conv2DNgchwFgchwOp, linalg::Conv2DNgchwGfchwOp,
// linalg::PoolingNdhwcSumOp>(producer)) {
// return false;
// }

auto producerFusionOp =
dyn_cast<IREE::LinalgExt::LinalgFusionOpInterface>(producer);
Expand Down

0 comments on commit 864d9df

Please sign in to comment.