Skip to content

Commit

Permalink
[Cleanup] Retire filter-based vectorization patterns. (iree-org#15185)
Browse files Browse the repository at this point in the history
The last usage is in GPUDistributeSharedMemoryCopy. The revision
replaces it with calling vectorize method in function walk.
  • Loading branch information
hanhanW authored Oct 18, 2023
1 parent 6b5b989 commit a9d7aa5
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 97 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
#include <algorithm>
#include <numeric>

#include "iree-dialects/Dialect/LinalgExt/Passes/Passes.h"
#include "iree-dialects/Dialect/LinalgExt/Transforms/Transforms.h"
#include "iree/compiler/Codegen/Common/GPU/PassDetail.h"
#include "iree/compiler/Codegen/Common/GPU/Passes.h"
Expand All @@ -29,9 +28,6 @@

#define DEBUG_TYPE "iree-codegen-gpu-distribute-shared-memory-copy"

using mlir::iree_compiler::IREE::LinalgExt::LinalgVectorizationPattern;
using mlir::iree_compiler::IREE::LinalgExt::VectorizationPatterns;

/// Prints the given `funcOp` after a leading `step` comment header.
void debugPrint(mlir::func::FuncOp funcOp, const char *step) {
LLVM_DEBUG({
Expand Down Expand Up @@ -274,14 +270,17 @@ static void populateTilingAndDistribute(RewritePatternSet &patterns,
StringAttr::get(patterns.getContext(), kCopyDistributed)));
}

static void populateVectorizationPatterns(RewritePatternSet &patterns) {
VectorizationPatterns<linalg::GenericOp>::insert(
patterns, IREE::LinalgExt::LinalgVectorizationOptions(),
IREE::LinalgExt::LinalgTransformationFilter(
{StringAttr::get(patterns.getContext(),
getCopyToWorkgroupMemoryMarker()),
StringAttr::get(patterns.getContext(), kCopyDistributed)},
std::nullopt));
static void vectorizeDistributedCopies(func::FuncOp funcOp) {
IRRewriter rewriter(funcOp.getContext());
SmallVector<linalg::GenericOp> candidates;
funcOp.walk([&](linalg::GenericOp op) { candidates.push_back(op); });
for (auto op : candidates) {
SmallVector<int64_t> vectorSizes;
SmallVector<bool> scalableVecDims;
scalableVecDims.resize(vectorSizes.size());
(void)linalg::vectorize(rewriter, op, vectorSizes, scalableVecDims,
/*vectorizeGatherAccesses=*/true);
};
}

/// Return a flattened Id Value by combining the 3D gpu thread IDs.
Expand Down Expand Up @@ -436,12 +435,7 @@ class GPUDistributeSharedMemoryCopyPass
debugPrint(funcOp, "After step 2: thread distribution");

// Step 3. Vectorize the distributed copies.
RewritePatternSet vectorizationPatterns(context);
populateVectorizationPatterns(vectorizationPatterns);
if (failed(applyPatternsAndFoldGreedily(
funcOp, std::move(vectorizationPatterns)))) {
return signalPassFailure();
}
vectorizeDistributedCopies(funcOp);
debugPrint(funcOp, "After step 3: vectorization");

// Step4. Finally unroll all the loop created
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -169,59 +169,6 @@ class TilingPatterns<OpTy, OpTypes...> {
}
};

///
/// Linalg vectorization patterns.
///
/// `filter` controls LinalgTransformMarker matching and update when specified.
/// See `vectorizeLinalgOp` for more details.
struct LinalgVectorizationPattern
: public OpInterfaceRewritePattern<linalg::LinalgOp> {
/// Construct a generic pattern applied to all LinalgOp that verify `filter`.
LinalgVectorizationPattern(
MLIRContext *context,
LinalgVectorizationOptions opts = LinalgVectorizationOptions(),
LinalgTransformationFilter f = LinalgTransformationFilter(),
PatternBenefit benefit = 1);

/// Construct a pattern specifically applied to `opName`.
LinalgVectorizationPattern(
StringRef opName, MLIRContext *context,
LinalgVectorizationOptions opts = LinalgVectorizationOptions(),
LinalgTransformationFilter f = LinalgTransformationFilter(),
PatternBenefit benefit = 1);

LogicalResult matchAndRewrite(linalg::LinalgOp linalgOp,
PatternRewriter &rewriter) const override;

private:
/// LinalgTransformMarker handles special attribute manipulations.
LinalgVectorizationOptions options;
LinalgTransformationFilter filter;
};

template <typename... OpTypes>
class VectorizationPatterns;

template <>
class VectorizationPatterns<> {
public:
static void insert(RewritePatternSet &patterns,
const LinalgVectorizationOptions &opts,
const LinalgTransformationFilter &f) {}
};

template <typename OpTy, typename... OpTypes>
class VectorizationPatterns<OpTy, OpTypes...> {
public:
static void insert(RewritePatternSet &patterns,
const LinalgVectorizationOptions &opts,
const LinalgTransformationFilter &f) {
patterns.add<LinalgVectorizationPattern>(OpTy::getOperationName(),
patterns.getContext(), opts, f);
VectorizationPatterns<OpTypes...>::insert(patterns, opts, f);
}
};

///
/// Linalg promotion patterns.
///
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,32 +100,6 @@ LinalgTilingPattern::returningMatchAndRewrite(linalg::LinalgOp op,
return res;
}

LinalgVectorizationPattern::LinalgVectorizationPattern(
MLIRContext *context, LinalgVectorizationOptions opts,
LinalgExt::LinalgTransformationFilter f, PatternBenefit benefit)
: OpInterfaceRewritePattern<linalg::LinalgOp>(context, benefit),
options(std::move(opts)), filter(std::move(f)) {}

LinalgVectorizationPattern::LinalgVectorizationPattern(
StringRef opName, MLIRContext *context, LinalgVectorizationOptions opts,
LinalgExt::LinalgTransformationFilter f, PatternBenefit benefit)
: OpInterfaceRewritePattern<linalg::LinalgOp>(context, benefit),
options(std::move(opts)), filter(f.addOpNameFilter(opName)) {}

LogicalResult
LinalgVectorizationPattern::matchAndRewrite(linalg::LinalgOp linalgOp,
PatternRewriter &rewriter) const {
if (failed(filter.checkAndNotify(rewriter, linalgOp)))
return failure();
SmallVector<int64_t> vectorSizes;
if (options.enableVectorMasking)
vectorSizes.append(options.vectorSizeComputationFunction(
linalgOp, options.canonicalVectorSizes));
SmallVector<bool> scalableVecDims(vectorSizes.size(), false);
return vectorize(rewriter, linalgOp, vectorSizes, scalableVecDims,
options.vectorizeGatherAccesses);
}

} // namespace LinalgExt
} // namespace IREE
} // namespace iree_compiler
Expand Down

0 comments on commit a9d7aa5

Please sign in to comment.