From 70eb0e37a86747f9266e4c8380baa89746f5e23b Mon Sep 17 00:00:00 2001 From: Rob Suderman Date: Mon, 29 Jan 2024 20:32:15 -0800 Subject: [PATCH] [mlir][tensor] Fix `tensor.pad` to remove newly static values (#79938) The canonicalization incrementally converts foldable dynamic hi/lo padding to static hi/lo values. During this canonicalization the static-fied valued should be removed from the dynamic values. --- mlir/lib/Dialect/Tensor/IR/TensorOps.cpp | 6 +++++- mlir/test/Dialect/Tensor/canonicalize.mlir | 2 +- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp index b2fe58099b2fb3..b21e89ae3a5713 100644 --- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp +++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp @@ -3158,19 +3158,23 @@ struct FoldStaticPadding : public OpRewritePattern { // Extract the static info from the high and low operands. SmallVector constOperandsLow; + SmallVector newLows; for (auto operand : padTensorOp.getLow()) { APSInt intOp; if (!matchPattern(operand, m_ConstantInt(&intOp))) { constOperandsLow.push_back(ShapedType::kDynamic); + newLows.push_back(operand); continue; } constOperandsLow.push_back(intOp.getExtValue()); } SmallVector constOperandsHigh; + SmallVector newHighs; for (auto operand : padTensorOp.getHigh()) { APSInt intOp; if (!matchPattern(operand, m_ConstantInt(&intOp))) { constOperandsHigh.push_back(ShapedType::kDynamic); + newHighs.push_back(operand); continue; } constOperandsHigh.push_back(intOp.getExtValue()); @@ -3222,7 +3226,7 @@ struct FoldStaticPadding : public OpRewritePattern { newOutDims, padTensorOp.getType().getElementType()); auto newOp = rewriter.create( padTensorOp->getLoc(), newResultType, input, staticLow, staticHigh, - padTensorOp.getLow(), padTensorOp.getHigh(), padTensorOp.getNofold(), + newLows, newHighs, padTensorOp.getNofold(), getPrunedAttributeList(padTensorOp, PadOp::getAttributeNames())); IRMapping mapper; diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir index ed964071358ace..7192a719ceb13d 100644 --- a/mlir/test/Dialect/Tensor/canonicalize.mlir +++ b/mlir/test/Dialect/Tensor/canonicalize.mlir @@ -1361,7 +1361,7 @@ func.func @pad_same_static_shape(%arg0: tensor<5x6xf32>, %a: index) // CHECK-LABEL: func @pad_fold_static( // CHECK-SAME: %[[INPUT:.*]]: tensor) -> tensor { // CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32 -// CHECK: %[[PADDING:.*]] = arith.constant 4 : index +// CHECK-NOT: arith.constant 4 : index // CHECK: %[[PADDED:.*]] = tensor.pad %[[INPUT]] // CHECK-SAME: low[0, 4, 1, 1] high[0, 4, 1, 1] { // CHECK: ^bb0(%[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index, %[[ARG4:.*]]: index):