[Codegen] Push up the extract slice op (iree-org#19680)
Push extract_slice ops to the beginning of their block when all of their
operands are block arguments. This lets the bufferization framework know
about the presence of a subset buffer that can be reused.
pashu123 authored Jan 16, 2025
1 parent 08b44e2 commit 36c2353
Showing 5 changed files with 52 additions and 17 deletions.
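For intuition, here is a minimal sketch of the intended effect, distilled from the new test case below (hypothetical IR, not verbatim pass output). Before the pass, the extract_slice sits below an unrelated transfer_write into a tensor.empty:

  %0 = tensor.empty() : tensor<64x64xf16>
  ...
  %3 = vector.transfer_write %2, %0[%c0, %c0] {in_bounds = [true, true]} : vector<64x64xf16>, tensor<64x64xf16>
  %slice = tensor.extract_slice %arg2[%arg0, %c2, %1, %arg0] [1, 64, 1, 64] [1, 1, 1, 1] : tensor<2x4096x10x64xf16> to tensor<1x64x1x64xf16>
  %r = tensor.insert_slice %3 into %slice[0, 0, 0, 0] [1, 64, 1, 64] [1, 1, 1, 1] : tensor<64x64xf16> into tensor<1x64x1x64xf16>

After the pass, the extract_slice is hoisted above the transfer_write. Bufferization can then see that %slice aliases a subset of %arg2's buffer and write into it in place, so the allocation for tensor.empty can be elided.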
@@ -12,6 +12,7 @@
#include "mlir/Dialect/Tensor/Utils/Utils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/IR/Dominance.h"
#include "mlir/Interfaces/SubsetOpInterface.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -253,10 +254,39 @@ struct CastLikeInsertSliceOpFolder final
};
} // namespace

// Find the earliest insertion point inside `block` at which all operands of
// `op` are available. Block arguments are always available; only operands
// with defining ops constrain the insertion point.
static Operation *getEarliestInsertionPointInsideBlock(Block *block,
                                                       Operation *op) {
  Operation *currInsertionPoint = &(*block->getOperations().begin());
  DominanceInfo dominanceInfo(currInsertionPoint);

  for (Value operand : op->getOperands()) {
    // Block arguments dominate everything in the block; they never constrain
    // the insertion point.
    if (isa<BlockArgument>(operand)) {
      continue;
    }
    // If the defining op does not dominate the current candidate, the
    // insertion point has to move down to that defining op.
    Operation *defOp = operand.getDefiningOp();
    if (!dominanceInfo.dominates(defOp, currInsertionPoint)) {
      currInsertionPoint = defOp;
    }
  }
  return currInsertionPoint;
}
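
// Walkthrough on the IR from the new test below (a sketch, not verbatim):
//
//   %c0 = arith.constant 0 : index
//   %0  = tensor.empty() : tensor<64x64xf16>
//   %c2 = arith.constant 2 : index
//   %1  = arith.addi %arg0, %c2 : index
//   ...
//   %s  = tensor.extract_slice %arg2[%arg0, %c2, %1, %arg0] ...
//
// The operands %arg2 and %arg0 are block arguments and are skipped. Of the
// remaining defining ops, the arith.addi producing %1 is the last one in the
// block, so it is returned as the earliest legal insertion point and the
// extract_slice is moved to just after it, above the transfer_write.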

void OptimizeTensorInsertExtractSlicesPass::runOnOperation() {
auto funcOp = getOperation();
IRRewriter rewriter(funcOp->getContext());

  // TODO: This is a temporary hack enabled for bufferization to get rid of
  // empty buffers. Tracked here:
  // https://github.com/llvm/llvm-project/issues/122869
  funcOp.walk([&](tensor::ExtractSliceOp extractSliceOp) {
    Block *currBlock = extractSliceOp->getBlock();
    Operation *insertionPoint =
        getEarliestInsertionPointInsideBlock(currBlock, extractSliceOp);
    extractSliceOp->moveAfter(insertionPoint);
  });

funcOp.walk([&](scf::ForOp forOp) { moveLoopInvariantCode(forOp); });
LDBG("after hoisting loop invariant code\n" << funcOp);

@@ -321,3 +321,21 @@ func.func @fold_identity_extract_slice(%arg0: tensor<?xf32>) -> tensor<?xf32> {
// CHECK-LABEL: @fold_identity_extract_slice
// CHECK: %[[ARG0:.+]]: tensor<?xf32>
// CHECK: return %[[ARG0]]

// -----

func.func @push_up_extract_slice(%arg0: index, %arg1: vector<64x64xf32>, %arg2: tensor<2x4096x10x64xf16>) -> tensor<1x64x1x64xf16> {
%c0 = arith.constant 0 : index
%0 = tensor.empty() : tensor<64x64xf16>
%c2 = arith.constant 2 : index
%1 = arith.addi %arg0, %c2 : index
%2 = arith.truncf %arg1 : vector<64x64xf32> to vector<64x64xf16>
%3 = vector.transfer_write %2, %0[%c0, %c0] {in_bounds = [true, true]} : vector<64x64xf16>, tensor<64x64xf16>
%extracted_slice = tensor.extract_slice %arg2[%arg0, %c2, %1, %arg0] [1, 64, 1, 64] [1, 1, 1, 1] : tensor<2x4096x10x64xf16> to tensor<1x64x1x64xf16>
%inserted_slice = tensor.insert_slice %3 into %extracted_slice[0, 0, 0, 0] [1, 64, 1, 64] [1, 1, 1, 1] : tensor<64x64xf16> into tensor<1x64x1x64xf16>
return %inserted_slice : tensor<1x64x1x64xf16>
}

// CHECK-LABEL: @push_up_extract_slice
// CHECK: tensor.extract_slice
// CHECK: vector.transfer_write
@@ -142,14 +142,6 @@ hal.executable private @main {
// CHECK: scf.forall ({{.*}}) in (17, 81) {
// CHECK: %[[LOOP:.+]] = scf.for %[[IV:.+]] = %[[C0]] to %[[C721]] step %[[C1]] {{.*}} -> (vector<1x1x1x1x4x1xf32>)
// CHECK: gpu.barrier
// CHECK-DAG: %[[LHS_RD:.+]] = vector.transfer_read %[[B0]]{{.*}} vector<1xf16>
// CHECK-DAG: vector.transfer_write %[[LHS_RD]]
// Note that, to simplify the test, we do not show the mapping of the RHS_RD
// to its buffer, as it goes through an scf.if/else control structure
// involving allocas.
// CHECK-DAG: %[[RHS_RD:.+]] = vector.transfer_read {{.*}} vector<1xf16>
// CHECK-DAG: vector.transfer_write %[[RHS_RD]]
// CHECK: gpu.barrier
// CHECK-DAG: %[[LHS_MM0:.+]] = vector.transfer_read {{.*}} vector<4xf16>
// CHECK-DAG: %[[RHS_MM:.+]] = vector.transfer_read {{.*}} vector<4x1x1xf16>
// CHECK-COUNT-1: amdgpu.mfma {{.*}}blocks = 1 : i32, k = 16 : i32, m = 16 : i32, n = 16 : i32
@@ -1151,11 +1151,6 @@ hal.executable public @main {
// CHECK: scf.forall ({{.*}}) in (12, 37, 10) {
// CHECK: %[[LOOP:.+]] = scf.for %[[IV:.+]] = %c0 to %c145 step %c1 {{.*}} -> (vector<1x1x1x4x1xf32>)
// CHECK: gpu.barrier
// CHECK-DAG: %[[LHS_RD:.+]] = vector.transfer_read {{.*}} vector<4xf32>
// CHECK-DAG: vector.transfer_write %[[LHS_RD]]
// CHECK-DAG: %[[RHS_RD:.+]] = vector.transfer_read {{.*}} vector<1xf32>
// CHECK-DAG: vector.transfer_write %[[RHS_RD]]
// CHECK: gpu.barrier
// CHECK-DAG: vector.transfer_read {{.*}} #gpu.address_space<workgroup>>, vector<1xf32>
// CHECK-DAG: vector.transfer_read {{.*}} #gpu.address_space<workgroup>>, vector<1xf32>
// CHECK-COUNT-1: amdgpu.mfma {{.*}}blocks = 1 : i32, k = 4 : i32, m = 16 : i32, n = 16 : i32
@@ -552,16 +552,16 @@ hal.executable.variant @rocm target(<"rocm", "rocm-hsaco-fb">) {
// CHECK-DAG: %[[RHS_GLOBAL:.+]] = hal.interface.binding.subspan layout({{.+}}) binding(1) alignment(64) offset(%c0) flags(ReadOnly) : memref<64x1281x1281xf16, #hal.descriptor_type<storage_buffer>>
// CHECK-DAG: %[[OUT_GLOBAL:.+]] = hal.interface.binding.subspan layout({{.+}}) binding(2) alignment(64) offset(%c0) : memref<64x968x1281xf16, #hal.descriptor_type<storage_buffer>>
// CHECK-DAG: %[[LHS_GLOBAL_SUB:.+]] = memref.subview %[[LHS_GLOBAL]]
// CHECK: %[[LHS_LOAD:.+]] = vector.transfer_read %[[LHS_GLOBAL_SUB]]{{.+}} {in_bounds = [true, false, false]}
// CHECK-DAG: %[[LHS_LOAD:.+]] = vector.transfer_read %[[LHS_GLOBAL_SUB]]{{.+}} {in_bounds = [true, false, false]}
// CHECK-DAG: %[[RHS_GLOBAL_SUB:.+]] = memref.subview %[[RHS_GLOBAL]]
// CHECK: %[[RHS_LOAD:.+]] = vector.transfer_read %[[RHS_GLOBAL_SUB]]{{.+}} {in_bounds = [true, false, false]}
// CHECK-DAG: %[[RHS_LOAD:.+]] = vector.transfer_read %[[RHS_GLOBAL_SUB]]{{.+}} {in_bounds = [true, false, false]}
// CHECK: vector.transfer_write %[[LHS_LOAD]], %[[LHS_SHARED]]
// CHECK: vector.transfer_write %[[RHS_LOAD]], %[[RHS_SHARED]]
// CHECK: %[[RES:.+]] scf.for {{.*}} = %c0 to %c1280 step %c16 iter_args({{.*}}) -> (vector<1x1x1x1x1x1x1x4x1xf16>)
// CHECK-DAG: %[[LHS_GLOBAL_SUB:.+]] = memref.subview %[[LHS_GLOBAL]]
// CHECK: %[[LHS_LOAD:.+]] = vector.transfer_read %[[LHS_GLOBAL_SUB]]
// CHECK-DAG: %[[LHS_LOAD:.+]] = vector.transfer_read %[[LHS_GLOBAL_SUB]]
// CHECK-DAG: %[[RHS_GLOBAL_SUB:.+]] = memref.subview %[[RHS_GLOBAL]]
// CHECK: %[[RHS_LOAD:.+]] = vector.transfer_read %[[RHS_GLOBAL_SUB]]{{.+}} {in_bounds = [true, false, false]}
// CHECK-DAG: %[[RHS_LOAD:.+]] = vector.transfer_read %[[RHS_GLOBAL_SUB]]{{.+}} {in_bounds = [true, false, false]}
// CHECK: gpu.barrier
// CHECK-DAG: %{{.+}} = vector.transfer_read %[[LHS_SHARED]]
// CHECK-DAG: %{{.+}} = vector.transfer_read %[[RHS_SHARED]]
