Skip to content

Commit

Permalink
Allow packing matmul dims equal to block size (#961)
Browse files Browse the repository at this point in the history
Relaxes the validation of blocked dimensions in matmul packing to allow packing
when `matmul dims >= block factor`.
  • Loading branch information
adam-smnk authored Sep 4, 2024
1 parent 2dfa9c2 commit 67f9830
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 7 deletions.
3 changes: 2 additions & 1 deletion lib/TPP/Transforms/ToBlockLayoutAndBack.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -515,7 +515,8 @@ struct PackMatmul : public tpp::impl::PackMatmulBase<PackMatmul> {
size_t posK = 2 + inc;
if (!linalgx::utils::validateFullTilesOnDims(
cast<TilingInterface>(linalgOp.getOperation()),
{tileOnI, tileOnJ, tileOnK}, {posI, posJ, posK})) {
{tileOnI, tileOnJ, tileOnK}, {posI, posJ, posK},
/*minTileFactor=*/1)) {
return std::nullopt;
}

Expand Down
8 changes: 2 additions & 6 deletions test/Passes/DefaultPipeline/default-tpp-passes.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -219,16 +219,12 @@ func.func @softmax(%arg0: tensor<2x2x2x2xf32>, %arg1: tensor<2x2x2x2xf32>) -> te
// CHECK-LABEL: batch_matmul_rewrite
func.func @batch_matmul_rewrite(%arg0: tensor<512x32x64xf32>, %arg1: tensor<512x64x32xf32>) -> tensor<512x32x32xf32> {
%0 = tensor.empty() : tensor<512x32x32xf32>
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : i64
// CHECK-DAG: %[[C32:.+]] = arith.constant 32 : i64
// CHECK-DAG: %[[C64:.+]] = arith.constant 64 : i64
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : i64
// CHECK-DAG: %[[C0_i:.+]] = arith.constant 0 : index
// CHECK-DAG: %[[C1_i:.+]] = arith.constant 1 : index
// CHECK-DAG: %[[C512_i:.+]] = arith.constant 512 : index
// CHECK: %{{.+}} = call @xsmm_gemm_dispatch(%[[C1]], %[[C32]], %[[C32]], %[[C64]], %[[C64]], %[[C32]], %[[C32]], %[[C0]])
// CHECK: %{{.+}} = call @xsmm_brgemm_dispatch
// CHECK: scf.parallel{{.*}}(%[[C0_i]]) to (%[[C512_i]]) step (%[[C1_i]])
// CHECK: xsmm_gemm_invoke
// CHECK: xsmm_brgemm_invoke
%1 = linalg.batch_matmul ins(%arg0, %arg1 : tensor<512x32x64xf32>, tensor<512x64x32xf32>)
outs(%0 : tensor<512x32x32xf32>) -> tensor<512x32x32xf32>
return %1 : tensor<512x32x32xf32>
Expand Down
30 changes: 30 additions & 0 deletions test/Passes/pass-matmul-blocking.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,36 @@ func.func @block_linalg_matmul(

// -----

// All three matmul operands are 32x32. The FileCheck expectations that follow
// this function require each operand to be packed with inner_tiles = [32, 32]
// into a tensor<1x1x32x32xf32>, i.e. exactly one tile per packed dimension.
func.func @block_dims_equal_to_factors(
%arg0: tensor<32x32xf32>, %arg1: tensor<32x32xf32>, %arg2: tensor<32x32xf32>)
-> tensor<32x32xf32> {
%0 = linalg.matmul ins(%arg0, %arg1: tensor<32x32xf32>, tensor<32x32xf32>)
outs(%arg2: tensor<32x32xf32>)
-> tensor<32x32xf32>
return %0 : tensor<32x32xf32>
}

// CHECK-DAG: #[[MAP3:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d3, d5)>
// CHECK-DAG: #[[MAP4:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d2, d5, d4)>
// CHECK-DAG: #[[MAP5:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d4)>

// CHECK-LABEL: func @block_dims_equal_to_factors(
// CHECK-SAME: %[[ARG0:[0-9a-z]+]]: tensor<32x32xf32>
// CHECK-SAME: %[[ARG1:[0-9a-z]+]]: tensor<32x32xf32>
// CHECK-SAME: %[[ARG2:[0-9a-z]+]]: tensor<32x32xf32>) -> tensor<32x32xf32> {
// CHECK: %[[BUF0:.+]] = tensor.empty() : tensor<1x1x32x32xf32>
// CHECK: %[[PACK0:.+]] = tensor.pack %[[ARG0]] outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %[[BUF0]] : tensor<32x32xf32> -> tensor<1x1x32x32xf32>
// CHECK: %[[BUF1:.*]] = tensor.empty() : tensor<1x1x32x32xf32>
// CHECK: %[[PACK1:.+]] = tensor.pack %[[ARG1]] outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %[[BUF1]] : tensor<32x32xf32> -> tensor<1x1x32x32xf32>
// CHECK: %[[BUF2:.+]] = tensor.empty() : tensor<1x1x32x32xf32>
// CHECK: %[[PACK2:.+]] = tensor.pack %[[ARG2]] inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %[[BUF2]] : tensor<32x32xf32> -> tensor<1x1x32x32xf32>
// CHECK: %[[VAL:.+]] = linalg.generic {indexing_maps = [#[[MAP3]], #[[MAP4]], #[[MAP5]]], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]} ins(%[[PACK0]], %[[PACK1]] : tensor<1x1x32x32xf32>, tensor<1x1x32x32xf32>) outs(%[[PACK2]] : tensor<1x1x32x32xf32>)
// CHECK: %[[OUT:.+]] = tensor.unpack %[[VAL]] inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %[[ARG2]] : tensor<1x1x32x32xf32> -> tensor<32x32xf32>
// CHECK: return %[[OUT]] : tensor<32x32xf32>
// CHECK: }

// -----

// We don't expect to block as the blocking factors do not create full tiles.
func.func @block_linalg_matmul(
%arg0: tensor<5x6xf32>, %arg1: tensor<6x5xf32>, %arg2: tensor<5x5xf32>)
Expand Down

0 comments on commit 67f9830

Please sign in to comment.