forked from iree-org/iree
-
Notifications
You must be signed in to change notification settings - Fork 11
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[VectorDistribution] Set layouts before generalization and folding (iree-org#18186)

This patch allows setting anchors for linalg operations before generalization and unit-dim folding. It introduces two things:
- Unit-dim folding for to_layout ops: required for unit-dim folding in general, which is in turn required for 1x1ConvToMatmul.
- Teaching ConfigureTensorLayout to set layouts for convolutions with unit filter dims.
- Loading branch information
Showing
13 changed files
with
261 additions
and
32 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
105 changes: 105 additions & 0 deletions
105
...er/src/iree/compiler/Codegen/Dialect/VectorExt/Transforms/VectorExtFoldUnitExtentDims.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,105 @@ | ||
// Copyright 2024 The IREE Authors | ||
// | ||
// Licensed under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
|
||
#include "iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtDialect.h" | ||
#include "iree/compiler/Codegen/Dialect/VectorExt/Transforms/Passes.h" | ||
#include "mlir/Dialect/Tensor/IR/Tensor.h" | ||
#include "mlir/IR/Builders.h" | ||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h" | ||
|
||
namespace mlir::iree_compiler::IREE::VectorExt { | ||
|
||
#define GEN_PASS_DEF_VECTOREXTFOLDUNITEXTENTDIMSPASS | ||
#include "iree/compiler/Codegen/Dialect/VectorExt/Transforms/Passes.h.inc" | ||
|
||
namespace { | ||
|
||
struct DropToLayoutUnitDims final | ||
: OpRewritePattern<IREE::VectorExt::ToLayoutOp> { | ||
using OpRewritePattern::OpRewritePattern; | ||
|
||
LogicalResult matchAndRewrite(IREE::VectorExt::ToLayoutOp toLayoutOp, | ||
PatternRewriter &rewriter) const override { | ||
if (!toLayoutOp.hasTensorSemantics()) { | ||
return rewriter.notifyMatchFailure(toLayoutOp, | ||
"requires tensor semanticS"); | ||
} | ||
|
||
Location loc = toLayoutOp.getLoc(); | ||
ShapedType inputTy = toLayoutOp.getType(); | ||
ArrayRef<int64_t> shape = inputTy.getShape(); | ||
|
||
// Find list of dims to drop and the target shape. | ||
SmallVector<bool> unitDims(shape.size(), false); | ||
SmallVector<int64_t> targetShape; | ||
bool hasUnitDims = false; | ||
for (auto [idx, size] : llvm::enumerate(shape)) { | ||
if (size == 1) { | ||
unitDims[idx] = true; | ||
hasUnitDims = true; | ||
continue; | ||
} | ||
targetShape.push_back(size); | ||
} | ||
|
||
if (!hasUnitDims) { | ||
return rewriter.notifyMatchFailure(toLayoutOp, "no unit dims present"); | ||
} | ||
|
||
// Drop unit dims using extract_slice. | ||
FailureOr<Value> rankReducingExtract = | ||
tensor::ExtractSliceOp::rankReduceIfNeeded( | ||
rewriter, loc, toLayoutOp.getInput(), targetShape); | ||
assert(succeeded(rankReducingExtract) && "not a unit-extent collapse"); | ||
|
||
// Find the rank reduced layout. | ||
VectorLayoutInterface newLayout = toLayoutOp.getLayout().project(unitDims); | ||
|
||
Value rankReducedValue = rankReducingExtract.value(); | ||
auto newToLayoutOp = rewriter.create<IREE::VectorExt::ToLayoutOp>( | ||
loc, rankReducedValue.getType(), rankReducedValue, newLayout, | ||
toLayoutOp.getSharedMemoryConversion()); | ||
newToLayoutOp->setDiscardableAttrs( | ||
toLayoutOp->getDiscardableAttrDictionary()); | ||
|
||
// Expand to preserve output shape using insert_slice. | ||
// Here, since the shape comes from the result of a to_layout op, it will | ||
// always be static. | ||
Value dest = | ||
rewriter.create<tensor::EmptyOp>(loc, shape, inputTy.getElementType()); | ||
|
||
int64_t rank = inputTy.getRank(); | ||
SmallVector<OpFoldResult> offsets(rank, rewriter.getIndexAttr(0)); | ||
SmallVector<OpFoldResult> sizes = | ||
tensor::getMixedSizes(rewriter, loc, dest); | ||
SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1)); | ||
rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>( | ||
toLayoutOp, newToLayoutOp.getResult(), dest, offsets, sizes, strides); | ||
|
||
return success(); | ||
} | ||
}; | ||
|
||
} // namespace | ||
|
||
namespace { | ||
struct VectorExtFoldUnitExtentDimsPass final | ||
: impl::VectorExtFoldUnitExtentDimsPassBase< | ||
VectorExtFoldUnitExtentDimsPass> { | ||
void runOnOperation() override { | ||
|
||
MLIRContext *ctx = &getContext(); | ||
RewritePatternSet patterns(ctx); | ||
patterns.add<DropToLayoutUnitDims>(ctx); | ||
if (failed(applyPatternsAndFoldGreedily(getOperation(), | ||
std::move(patterns)))) { | ||
return signalPassFailure(); | ||
} | ||
} | ||
}; | ||
} // namespace | ||
|
||
} // namespace mlir::iree_compiler::IREE::VectorExt |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.