Retire downstream constant packing folder (#921)
Replaces our downstream implementation of the constant packing folder with
upstream patterns. The pass now only selects which pack ops to fold and
otherwise acts as a thin wrapper around the upstream logic.

The old implementation also had a bug in its element index calculation;
after the replacement, folded constants match the outputs produced by
runtime packing.
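
For illustration, the kind of IR this pass folds looks like the test case shown below (shapes taken from test/Passes/fold-pack-and-constant.mlir; the outer_dims_perm, inner_dims_pos, and inner_tiles attributes are an assumed reconstruction consistent with those shapes, not lines copied from the test):

func.func @fold_pack_example() -> tensor<8x2x1x1x32x32xi64> {
  %cst = arith.constant dense<1> : tensor<1x1x64x256xi64>
  %empty = tensor.empty() : tensor<8x2x1x1x32x32xi64>
  // The pack has no runtime inputs, so the whole body is expected to fold
  // into a single arith.constant of type tensor<8x2x1x1x32x32xi64>.
  %pack = tensor.pack %cst outer_dims_perm = [3, 2, 0, 1]
      inner_dims_pos = [2, 3] inner_tiles = [32, 32]
      into %empty : tensor<1x1x64x256xi64> -> tensor<8x2x1x1x32x32xi64>
  return %pack : tensor<8x2x1x1x32x32xi64>
}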
adam-smnk authored Jun 7, 2024
1 parent 46eff69 commit f307b5a
Showing 5 changed files with 120 additions and 172 deletions.
3 changes: 3 additions & 0 deletions include/TPP/Passes.td
@@ -216,6 +216,9 @@ def ConstantFoldPack : Pass<"constant-fold-pack", "ModuleOp"> {
   let description = [{
     Reduce pack overhead by folding tensor.pack into constant tensors.
   }];
+  let dependentDialects = ["linalg::LinalgDialect",
+                           "tensor::TensorDialect",
+                           "arith::ArithDialect"];
 }

 def ElementWiseFusion : Pass<"element-wise-fusion", "func::FuncOp"> {
210 changes: 50 additions & 160 deletions lib/TPP/Transforms/ConstantFoldPack.cpp
@@ -10,11 +10,11 @@
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
-#include "mlir/Dialect/Utils/IndexingUtils.h"
-#include "mlir/IR/Threading.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-#include "llvm/Support/Debug.h"

 using namespace mlir;

@@ -25,179 +25,69 @@ namespace tpp {
 } // namespace tpp
 } // namespace mlir

-#define DEBUG_TYPE "fold-pack-into-cst"
-
 namespace {

-struct ConstantFoldPack
-    : public tpp::impl::ConstantFoldPackBase<ConstantFoldPack> {
+// Helper pattern - lower tensor.pack operations that pack constants.
+struct LowerConstantPacking : public OpRewritePattern<tensor::PackOp> {
+  using OpRewritePattern<tensor::PackOp>::OpRewritePattern;

-  // Collect a packed constantOp and its attribute if any.
-  static FailureOr<std::pair<arith::ConstantOp, DenseElementsAttr>>
-  getDenseAttributeAndConstant(tensor::PackOp packOp) {
-    if (packOp.getPaddingValue())
-      return failure();
-    Value sourcePack = packOp.getSource();
-    auto cstOp = sourcePack.getDefiningOp<arith::ConstantOp>();
-    if (!cstOp)
+  LogicalResult matchAndRewrite(tensor::PackOp packOp,
+                                PatternRewriter &rewriter) const override {
+    auto constOp = packOp.getSource().getDefiningOp<arith::ConstantOp>();
+    if (!constOp)
       return failure();
-    auto cst = cstOp.getValue();
-    if (!isa<DenseElementsAttr>(cst))
+    // Must be a dense constant.
+    auto denseAttr = dyn_cast<DenseElementsAttr>(constOp.getValue());
+    if (!denseAttr)
       return failure();
-    auto oldDense = cast<DenseElementsAttr>(cst);
-    return std::make_pair(cstOp, oldDense);
-  }
-
-  static bool areStaticValues(ArrayRef<int64_t> tilesSizes) {
-    return !llvm::is_contained(tilesSizes, ShapedType::kDynamic);
-  }
-
-  void foldPackIntoCst(RewriterBase &rewriter, tensor::PackOp packOp) {
-    // Bail out if the user uses pack as a writable operation
-    // (i.e., the destination is not a tensor.empty).
+    // Bail out if the pack is used as a writing operation i.e., the destination
+    // is not a tensor.empty.
     if (!packOp.getDest().getDefiningOp<tensor::EmptyOp>())
-      return;
-    OpBuilder::InsertionGuard guard(rewriter);
-    auto cstAndAttribute = getDenseAttributeAndConstant(packOp);
-    if (failed(cstAndAttribute))
-      return;
-    auto [cstOp, oldDense] = *(cstAndAttribute);
-    // Happy path, splat constant.
-    if (oldDense.isSplat()) {
-      auto newDense = oldDense.reshape(packOp.getDestType());
-      rewriter.setInsertionPoint(cstOp);
-      rewriter.replaceOpWithNewOp<arith::ConstantOp>(packOp, newDense);
-      return;
-    }
-    LLVM_DEBUG(llvm::dbgs()
-               << "NUM ELEMENT: " << oldDense.getNumElements() << "\n");
-    const int64_t bytes =
-        oldDense.getRawData().size() / oldDense.getNumElements();
-
-    // The original buffer.
-    ArrayRef<char> rawData = oldDense.getRawData();
-    // The new buffer.
-    SmallVector<char> destRawData(rawData.size());
-
-    int64_t numberOfElements = oldDense.getNumElements();
-    SmallVector<int64_t> strides =
-        computeStrides(packOp.getDestType().getShape());
-    LLVM_DEBUG(llvm::dbgs() << "#STRIDES: " << strides.size() << "\n";
-               for (int64_t stride
-                    : strides) llvm::dbgs()
-               << stride << " ";
-               llvm::dbgs() << "\n";);
-
-    parallelFor(
-        packOp.getContext(), 0, numberOfElements,
-        [&](size_t destLinearizedIdx) {
-          // Step1. De-linearize destination index.
-          // f(lin) = tmp[A][B][C]
-          SmallVector<int64_t> delDestIndexes =
-              delinearize(destLinearizedIdx, strides);
-          assert(delDestIndexes.size() ==
-                 static_cast<size_t>(packOp.getDestType().getRank()));
-
-          // Step2. Arrange the indexes based on the packing
-          // information. Step 2.1: Compute inverse of outerDimsPerm to
-          // bring the loops into the canonical form tmp[A][B][a][b].
-          if (!packOp.getOuterDimsPerm().empty()) {
-            SmallVector<int64_t> inversePermutation =
-                invertPermutationVector(packOp.getOuterDimsPerm());
-            SmallVector<int64_t> tileLoops;
-            for (auto i = 0; i < packOp.getSourceType().getRank(); i++)
-              tileLoops.push_back(delDestIndexes[i]);
-            applyPermutationToVector(tileLoops, inversePermutation);
-            SmallVector<int64_t> pointLoops;
-            for (size_t i = packOp.getSourceType().getRank();
-                 i < delDestIndexes.size(); i++) {
-              pointLoops.push_back(delDestIndexes[i]);
-            }
-            delDestIndexes = tileLoops;
-            delDestIndexes.append(pointLoops.begin(), pointLoops.end());
-            assert(delDestIndexes.size() ==
-                   static_cast<size_t>(packOp.getDestType().getRank()));
-          }
-          // Step 2.2
-          // After interchanging the outermost tiled loop we end up in
-          // the canonical form tmp[A][B][a][b]. Squash the point loops
-          // with the tiled ones.
-          llvm::DenseSet<int64_t> tiledLoops(packOp.getInnerDimsPos().begin(),
-                                             packOp.getInnerDimsPos().end());
-          llvm::DenseMap<int64_t, int64_t> mappingTileToPointLoops;
-          // Map the position of the tiled loops with the point one. Example:
-          // [A][B] -> [A][B][a][b]
-          // entry: [A : 0] [a : 2]
-          // entry: [B : 1] [b : 3]
-          // [A][B] -> [A][B][b]
-          // entry: [B : 1] [b : 2]
-          for (auto tileLoop : llvm::enumerate(packOp.getInnerDimsPos()))
-            mappingTileToPointLoops[tileLoop.value()] = tileLoop.index();
-
-          SmallVector<int64_t> delSourceIndexes;
-          size_t tilePosIdx = 0;
-          SmallVector<int64_t> tilesSizes = packOp.getStaticTiles();
-          if (!areStaticValues(tilesSizes))
-            return;
-          int numberOfTileLoops = packOp.getSourceType().getRank();
-          for (int i = 0; i < numberOfTileLoops; i++) {
-            // Loop is not tiled.
-            if (!tiledLoops.count(i)) {
-              delSourceIndexes.push_back(delDestIndexes[i]);
-              // Loop is tiled, the point loop is at distance:
-              // numberOfTileLoops + mappingTileToPointLoops[i].
-            } else {
-              delSourceIndexes.push_back(
-                  delDestIndexes[i] * tilesSizes[tilePosIdx] +
-                  delDestIndexes[numberOfTileLoops +
-                                 mappingTileToPointLoops[i]]);
-              tilePosIdx++;
-            }
-          }
-          assert(delSourceIndexes.size() ==
-                 static_cast<size_t>(packOp.getSourceType().getRank()));
-          int64_t sourceLinearizedIdx =
-              linearize(delSourceIndexes,
-                        computeStrides(packOp.getSourceType().getShape()));
-          assert(sourceLinearizedIdx < numberOfElements);
-          LLVM_DEBUG(llvm::dbgs() << "dest index: " << destLinearizedIdx
-                                  << " map to source index: "
-                                  << sourceLinearizedIdx << "\n");
+      return rewriter.notifyMatchFailure(packOp,
+                                         "expects empty tensor destination");
+    // Pack destination must have static shape.
+    if (!packOp.getDestType().hasStaticShape())
+      return rewriter.notifyMatchFailure(
+          packOp, "expects destination with static shape");

+    // Pack with padding is not supported currently.
+    // TODO: Add tensor.pad folder pattern when available and lower the pack.
+    if (packOp.getPaddingValue())
+      return rewriter.notifyMatchFailure(packOp,
+                                         "NYI, expects no padding value");

-          // Step3. Do the packing.
-          for (int j = 0; j < bytes; j++) {
-            destRawData[destLinearizedIdx * bytes + j] =
-                rawData[sourceLinearizedIdx * bytes + j];
-          }
-        });
+    // If it is a splat constant, skip and let tensor.pack folder to handle this
+    // case.
+    if (denseAttr.isSplat())
+      return rewriter.notifyMatchFailure(
+          packOp, "skip pack - existing folder covers constant splats");

-    [[maybe_unused]] bool detectSpalt = false;
-    assert(DenseElementsAttr::isValidRawBuffer(packOp.getDestType(),
-                                               destRawData, detectSpalt));
-    auto newDense =
-        DenseElementsAttr::getFromRawBuffer(packOp.getDestType(), destRawData);
-    rewriter.setInsertionPoint(cstOp);
-    rewriter.replaceOpWithNewOp<arith::ConstantOp>(packOp, newDense);
+    return linalg::lowerPack(rewriter, packOp);
   }
+};

-  void foldPackIntoFill(RewriterBase &rewriter, tensor::PackOp packOp) {
-    OpBuilder::InsertionGuard guard(rewriter);
-    Value sourcePack = packOp.getSource();
-    auto fillOp = sourcePack.getDefiningOp<linalg::FillOp>();
-    if (!fillOp)
-      return;
-    rewriter.setInsertionPoint(packOp);
-    rewriter.replaceOpWithNewOp<linalg::FillOp>(packOp, fillOp.getInputs()[0],
-                                                packOp.getDest());
-  }
+// Rewrite constant packing operation as a compile-time packed constant.
+struct ConstantFoldPack
+    : public tpp::impl::ConstantFoldPackBase<ConstantFoldPack> {

   void runOnOperation() override {
     auto module = getOperation();
-    IRRewriter rewriter(&getContext());
-    module->walk(
-        [&](tensor::PackOp packOp) { foldPackIntoFill(rewriter, packOp); });
-    module->walk(
-        [&](tensor::PackOp packOp) { foldPackIntoCst(rewriter, packOp); });
+    auto *ctx = &getContext();
+
+    // TODO: Add tensor.pad folder pattern when available.
+    RewritePatternSet patterns(ctx);
+    // Temporarily lower constant packing operation to allow other existing
+    // patterns to fold the operation completely.
+    patterns.add<LowerConstantPacking>(ctx);
+    // Apply canonicalization to fold trivial cases and linalg constant folders
+    // to cleanup lowered packs.
+    linalg::FillOp::getCanonicalizationPatterns(patterns, ctx);
+    tensor::PackOp::getCanonicalizationPatterns(patterns, ctx);
+    linalg::populateConstantFoldLinalgOperations(
+        patterns, [](OpOperand *) -> bool { return true; });
+
+    (void)applyPatternsAndFoldGreedily(module, std::move(patterns));
   }
 };

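The retired pass computed packed buffers by hand; the new one defers to upstream: LowerConstantPacking calls linalg::lowerPack, which in upstream MLIR rewrites a padding-free tensor.pack into a tensor.expand_shape followed by a linalg.transpose, and the registered canonicalizations plus linalg constant folders then collapse that chain into a single packed constant. A rough sketch of the intermediate IR for the example above (reassociation and permutation are assumed from the pack metadata; tensor.expand_shape syntax varies across MLIR versions):

// tensor.pack %cst ... into %empty lowers approximately to:
%expanded = tensor.expand_shape %cst [[0], [1], [2, 3], [4, 5]]
    : tensor<1x1x64x256xi64> into tensor<1x1x2x32x8x32xi64>
%packed = linalg.transpose ins(%expanded : tensor<1x1x2x32x8x32xi64>)
    outs(%empty : tensor<8x2x1x1x32x32xi64>)
    permutation = [4, 2, 0, 1, 3, 5]
// Folding the transpose of an arith.constant then materializes the packed
// constant directly, matching what runtime packing would produce.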
2 changes: 1 addition & 1 deletion test/Passes/fold-pack-and-constant.mlir
@@ -1,4 +1,4 @@
-// RUN: tpp-opt %s -constant-fold-pack -canonicalize -cse -split-input-file | FileCheck %s
+// RUN: tpp-opt %s -constant-fold-pack -cse -split-input-file | FileCheck %s

 func.func @expect_to_fold_cst() -> tensor<8x2x1x1x32x32xi64> {
   %cst = arith.constant dense<1> : tensor<1x1x64x256xi64>
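With folding now done by the greedy rewrite, the function above should collapse to a bare constant. Hypothetical FileCheck lines for the folded form (illustrative only, not copied from the updated test):

// CHECK-LABEL: func.func @expect_to_fold_cst
// CHECK-NOT: tensor.pack
// CHECK: %[[CST:.+]] = arith.constant dense<1> : tensor<8x2x1x1x32x32xi64>
// CHECK: return %[[CST]]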
