Switch to scf::tileUsingSCF (#920)
Replaces TPP's usage of linalg::tileToForallOpUsingTileSizes with
scf::tileUsingSCF, in preparation for the upstream deprecation of the
Linalg API.

The corresponding test is updated: the SCF tiling API folds affine maps
as part of tiling, resulting in simpler IR.

Fixes #676
adam-smnk authored Jun 5, 2024
1 parent f7b38c4 · commit 46eff69
Showing 2 changed files with 20 additions and 22 deletions.
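Note on the API shapes involved (as of the upstream MLIR API around this commit): linalg::tileToForallOpUsingTileSizes returned a result exposing the generated scf.forall as tileOp, so callers replaced the original op with tileOp->getResults(). scf::tileUsingSCF instead consumes an scf::SCFTilingOptions bundle (tile sizes plus the desired loop type) and returns an scf::SCFTilingResult whose replacements field lists the values that stand in for the original op's results.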
11 changes: 7 additions & 4 deletions lib/TPP/Transforms/RewriteBatchMatmulToMatmul.cpp
@@ -103,12 +103,15 @@ struct RewriteBatchMatmulToMatmul
       tiles[0] = getAsIndexOpFoldResult(rewriter.getContext(), 1);
       OpBuilder::InsertionGuard guard(rewriter);
       rewriter.setInsertionPoint(batchMatmulOp);
-      auto tilingResult = linalg::tileToForallOpUsingTileSizes(
-          rewriter, cast<TilingInterface>(batchMatmulOp.getOperation()), tiles,
-          /*mapping=*/std::nullopt);
+      scf::SCFTilingOptions tilingOpts;
+      tilingOpts.setTileSizes(tiles);
+      tilingOpts.setLoopType(scf::SCFTilingOptions::LoopType::ForallOp);
+      auto tilingResult = scf::tileUsingSCF(
+          rewriter, cast<TilingInterface>(batchMatmulOp.getOperation()),
+          tilingOpts);
       if (failed(tilingResult))
         return signalPassFailure();
-      rewriter.replaceOp(batchMatmulOp, tilingResult->tileOp->getResults());
+      rewriter.replaceOp(batchMatmulOp, tilingResult->replacements);
     });
 
   // Step2:
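For anyone migrating similar call sites, below is a minimal, self-contained sketch of the new entry point. It assumes the upstream MLIR API as of this commit; tileWithForall is a hypothetical helper name used only for illustration, not part of TPP or MLIR.

#include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"
#include "mlir/Interfaces/TilingInterface.h"

using namespace mlir;

// Hypothetical helper (illustration only): tile an op with the given tile
// sizes and materialize the loop nest as scf.forall, as the pass above does.
static LogicalResult tileWithForall(RewriterBase &rewriter, TilingInterface op,
                                    ArrayRef<OpFoldResult> tiles) {
  scf::SCFTilingOptions tilingOpts;
  tilingOpts.setTileSizes(tiles);
  // Request scf.forall instead of the default scf.for loop nest.
  tilingOpts.setLoopType(scf::SCFTilingOptions::LoopType::ForallOp);
  FailureOr<scf::SCFTilingResult> tilingResult =
      scf::tileUsingSCF(rewriter, op, tilingOpts);
  if (failed(tilingResult))
    return failure();
  // replacements carries the values standing in for the original results.
  rewriter.replaceOp(op, tilingResult->replacements);
  return success();
}

Because the loop type is now an option rather than a separate entry point, the same call site can switch between scf.for and scf.forall output without changing APIs.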
31 changes: 13 additions & 18 deletions test/Passes/pass-rewrite-batch-matmul-to-matmul.mlir
@@ -53,17 +53,14 @@ func.func @batch_matmul_rewrite(%arg0: tensor<512x?x?xf32>,
 
 // -----
 
-// TODO: tiling using scf.forall introduces the affine.min that prevents
-// rank reducing the tensor and map to brgemm. See: #676
 func.func @batch_matmul_rewrite(%arg0: tensor<?x?x?xf32>,
-    %arg1: tensor<?x?x?xf32>, %dim0: index, %dim1: index, %bacth: index) -> tensor<?x?x?xf32> {
-  %0 = tensor.empty(%bacth, %dim0, %dim1) : tensor<?x?x?xf32>
+    %arg1: tensor<?x?x?xf32>, %dim0: index, %dim1: index, %batch: index) -> tensor<?x?x?xf32> {
+  %0 = tensor.empty(%batch, %dim0, %dim1) : tensor<?x?x?xf32>
   %1 = linalg.batch_matmul ins(%arg0, %arg1 : tensor<?x?x?xf32>, tensor<?x?x?xf32>)
     outs(%0 : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
   return %1 : tensor<?x?x?xf32>
 }
 
-// CHECK: #[[MAP:.+]] = affine_map<(d0)[s0] -> (-d0 + s0, 1)>
 // CHECK-LABEL: batch_matmul_rewrite
 // CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x?xf32>, %[[ARG1:.+]]: tensor<?x?x?xf32>,
 // CHECK-SAME: %[[ARG2:.+]]: index, %[[ARG3:.+]]: index, %[[ARG4:.+]]: index
@@ -72,18 +69,16 @@ func.func @batch_matmul_rewrite(%arg0: tensor<?x?x?xf32>,
 // CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
 // CHECK: %[[EMPTY:.+]] = tensor.empty(%[[ARG4]], %[[ARG2]], %[[ARG3]]) : tensor<?x?x?xf32>
 // CHECK: %[[DIM:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x?x?xf32>
-// CHECK: %[[DIM0:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x?x?xf32>
-// CHECK: %[[DIM1:.+]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<?x?x?xf32>
-// CHECK: %[[DIM2:.+]] = tensor.dim %[[ARG0]], %[[C2]] : tensor<?x?x?xf32>
-// CHECK: %[[DIM3:.+]] = tensor.dim %[[ARG1]], %[[C2]] : tensor<?x?x?xf32>
+// CHECK: %[[DIM0:.+]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<?x?x?xf32>
+// CHECK: %[[DIM1:.+]] = tensor.dim %[[ARG0]], %[[C2]] : tensor<?x?x?xf32>
+// CHECK: %[[DIM2:.+]] = tensor.dim %[[ARG1]], %[[C2]] : tensor<?x?x?xf32>
 // CHECK: %{{.+}} = scf.forall (%[[ARG5:.+]]) in (%[[DIM]])
 // CHECK-SAME: shared_outs(%[[ARG6:.+]] = %[[EMPTY]]) -> (tensor<?x?x?xf32>) {
-// CHECK: %[[MIN:.+]] = affine.min #[[MAP]](%[[ARG5]])[%[[DIM0]]]
-// CHECK: %[[SLICE:.+]] = tensor.extract_slice %[[ARG0]][%[[ARG5]], 0, 0] [%[[MIN]], %[[DIM1]], %[[DIM2]]] [1, 1, 1]
-// CHECK-SAME: : tensor<?x?x?xf32> to tensor<?x?x?xf32>
-// CHECK: %[[SLICE1:.+]] = tensor.extract_slice %[[ARG1]][%[[ARG5]], 0, 0] [%[[MIN]], %[[DIM2]], %[[DIM3]]] [1, 1, 1]
-// CHECK-SAME: : tensor<?x?x?xf32> to tensor<?x?x?xf32>
-// CHECK: %[[SLICE2:.+]] = tensor.extract_slice %[[ARG6]][%[[ARG5]], 0, 0] [%[[MIN]], %[[DIM1]], %[[DIM3]]] [1, 1, 1]
-// CHECK-SAME: : tensor<?x?x?xf32> to tensor<?x?x?xf32>
-// CHECK: %{{.+}} = linalg.batch_matmul ins(%[[SLICE]], %[[SLICE1]] : tensor<?x?x?xf32>, tensor<?x?x?xf32>)
-// CHECK-SAME: outs(%[[SLICE2]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+// CHECK: %[[SLICE:.+]] = tensor.extract_slice %[[ARG0]][%[[ARG5]], 0, 0] [1, %[[DIM0]], %[[DIM1]]] [1, 1, 1]
+// CHECK-SAME: : tensor<?x?x?xf32> to tensor<?x?xf32>
+// CHECK: %[[SLICE1:.+]] = tensor.extract_slice %[[ARG1]][%[[ARG5]], 0, 0] [1, %[[DIM1]], %[[DIM2]]] [1, 1, 1]
+// CHECK-SAME: : tensor<?x?x?xf32> to tensor<?x?xf32>
+// CHECK: %[[SLICE2:.+]] = tensor.extract_slice %[[ARG6]][%[[ARG5]], 0, 0] [1, %[[DIM0]], %[[DIM2]]] [1, 1, 1]
+// CHECK-SAME: : tensor<?x?x?xf32> to tensor<?x?xf32>
+// CHECK: %{{.+}} = linalg.matmul ins(%[[SLICE]], %[[SLICE1]] : tensor<?x?xf32>, tensor<?x?xf32>)
+// CHECK-SAME: outs(%[[SLICE2]] : tensor<?x?xf32>) -> tensor<?x?xf32>
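Why the expected IR gets simpler: the pass tiles the batch dimension with size 1, so the old bound %[[MIN]] = affine.min affine_map<(d0)[s0] -> (-d0 + s0, 1)> always evaluates to 1 for any in-range induction variable, and scf::tileUsingSCF folds it away during tiling. With a static size of 1 in the leading slice position, each tensor.extract_slice rank-reduces from tensor<?x?x?xf32> to tensor<?x?xf32>, and the loop body becomes a plain linalg.matmul instead of a unit-batch linalg.batch_matmul, removing the obstacle described in the deleted TODO and closing #676.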
