From 80695c9d2295e9042d1716d9007db4623a9b1c0f Mon Sep 17 00:00:00 2001
From: Adam Siemieniuk
Date: Wed, 4 Sep 2024 18:01:40 +0200
Subject: [PATCH 1/3] Pack small dims

---
 lib/TPP/Transforms/ToBlockLayoutAndBack.cpp | 37 ++++++++++++++++++++-
 1 file changed, 36 insertions(+), 1 deletion(-)

diff --git a/lib/TPP/Transforms/ToBlockLayoutAndBack.cpp b/lib/TPP/Transforms/ToBlockLayoutAndBack.cpp
index aac5ea307..40a1576d8 100644
--- a/lib/TPP/Transforms/ToBlockLayoutAndBack.cpp
+++ b/lib/TPP/Transforms/ToBlockLayoutAndBack.cpp
@@ -274,6 +274,20 @@ packConvolutions(RewriterBase &rewriter, OpTy convOp,
   return replacementOp;
 }
 
+/// Return the constant range span or nullopt otherwise.
+static std::optional<int64_t> getConstantRange(const Range &range) {
+  std::optional<int64_t> stride = getConstantIntValue(range.stride);
+  if (!stride || *stride != 1)
+    return std::nullopt;
+  std::optional<int64_t> offset = getConstantIntValue(range.offset);
+  if (!offset)
+    return std::nullopt;
+  std::optional<int64_t> size = getConstantIntValue(range.size);
+  if (!size)
+    return std::nullopt;
+  return (*size - *offset);
+}
+
 //===----------------------------------------------------------------------===//
 // Conv2DNhwcHwcfOp
 //===----------------------------------------------------------------------===//
@@ -478,6 +492,7 @@ struct PackMatmul : public tpp::impl::PackMatmulBase<PackMatmul> {
     MLIRContext *ctx = getOperation().getContext();
     RewritePatternSet patterns(ctx);
 
+    // TODO: Add a cost function that decides whether to pack at all.
     auto packControlFn = [&](linalg::LinalgOp linalgOp)
         -> std::optional<linalg::BlockPackMatmulOptions> {
       linalg::BlockPackMatmulOptions options;
@@ -501,8 +516,28 @@ struct PackMatmul : public tpp::impl::PackMatmulBase<PackMatmul> {
       // Allow padding to avoid double checks.
       options.allowPadding = true;
 
-      // Apply more restrictive packing validation.
+      // Adjust block factors to smaller dimensions.
+      // If a dimension is smaller than the blocking factor, then
+      // try to block by the dimension size.
+      auto dims = linalg::inferContractionDims(linalgOp);
+      if (failed(dims))
+        return std::nullopt;
+      OpBuilder builder(linalgOp);
+      auto tileOp = cast<TilingInterface>(linalgOp.getOperation());
+      SmallVector<Range> iterationDomain = tileOp.getIterationDomain(builder);
+
+      if (std::optional<int64_t> dimM =
+              getConstantRange(iterationDomain[dims->m.back()]))
+        options.blockFactors[0] = std::min(*dimM, options.blockFactors[0]);
+      if (std::optional<int64_t> dimN =
+              getConstantRange(iterationDomain[dims->n.back()]))
+        options.blockFactors[1] = std::min(*dimN, options.blockFactors[1]);
+      if (std::optional<int64_t> dimK =
+              getConstantRange(iterationDomain[dims->k.back()]))
+        options.blockFactors[2] = std::min(*dimK, options.blockFactors[2]);
+
+      // Apply more restrictive packing validation.
       SmallVector<OpFoldResult> tiles =
           getAsOpFoldResult(builder.getI64ArrayAttr(options.blockFactors));
       OpFoldResult tileOnI = tiles[0];

From 97f05a19d141bb07c1610df34a29ef1fa3607b92 Mon Sep 17 00:00:00 2001
From: Adam Siemieniuk
Date: Wed, 4 Sep 2024 18:57:48 +0200
Subject: [PATCH 2/3] Refactor util

---
 include/TPP/Transforms/Utils/TransformUtils.h |  3 +++
 lib/TPP/Transforms/ToBlockLayoutAndBack.cpp   | 20 +++----------------
 lib/TPP/Transforms/TransformUtils.cpp         |  2 +-
 3 files changed, 7 insertions(+), 18 deletions(-)

diff --git a/include/TPP/Transforms/Utils/TransformUtils.h b/include/TPP/Transforms/Utils/TransformUtils.h
index 1034cf3ac..f1e26b669 100644
--- a/include/TPP/Transforms/Utils/TransformUtils.h
+++ b/include/TPP/Transforms/Utils/TransformUtils.h
@@ -75,6 +75,9 @@ bool isBlockedMatmul(Operation *op);
 
 FailureOr<linalg::ContractionDimensions> isContraction(linalg::LinalgOp linalgOp);
 
+// Return the constant range span or nullopt otherwise.
+std::optional<int64_t> getConstantRange(const Range &range);
+
 // Validate a tile configuration for a linalgOp when we can statically do that.
 // Specific dims can be passed using 'dims'. If dims is empty the validation
 // will start from the outermost dimension, moving to innermost ones up to the
diff --git a/lib/TPP/Transforms/ToBlockLayoutAndBack.cpp b/lib/TPP/Transforms/ToBlockLayoutAndBack.cpp
index 40a1576d8..95b1da5d9 100644
--- a/lib/TPP/Transforms/ToBlockLayoutAndBack.cpp
+++ b/lib/TPP/Transforms/ToBlockLayoutAndBack.cpp
@@ -274,20 +274,6 @@ packConvolutions(RewriterBase &rewriter, OpTy convOp,
   return replacementOp;
 }
 
-/// Return the constant range span or nullopt otherwise.
-static std::optional<int64_t> getConstantRange(const Range &range) {
-  std::optional<int64_t> stride = getConstantIntValue(range.stride);
-  if (!stride || *stride != 1)
-    return std::nullopt;
-  std::optional<int64_t> offset = getConstantIntValue(range.offset);
-  if (!offset)
-    return std::nullopt;
-  std::optional<int64_t> size = getConstantIntValue(range.size);
-  if (!size)
-    return std::nullopt;
-  return (*size - *offset);
-}
-
 //===----------------------------------------------------------------------===//
 // Conv2DNhwcHwcfOp
 //===----------------------------------------------------------------------===//
@@ -528,13 +514,13 @@ struct PackMatmul : public tpp::impl::PackMatmulBase<PackMatmul> {
       SmallVector<Range> iterationDomain = tileOp.getIterationDomain(builder);
 
       if (std::optional<int64_t> dimM =
-              getConstantRange(iterationDomain[dims->m.back()]))
+              linalgx::utils::getConstantRange(iterationDomain[dims->m.back()]))
         options.blockFactors[0] = std::min(*dimM, options.blockFactors[0]);
       if (std::optional<int64_t> dimN =
-              getConstantRange(iterationDomain[dims->n.back()]))
+              linalgx::utils::getConstantRange(iterationDomain[dims->n.back()]))
         options.blockFactors[1] = std::min(*dimN, options.blockFactors[1]);
       if (std::optional<int64_t> dimK =
-              getConstantRange(iterationDomain[dims->k.back()]))
+              linalgx::utils::getConstantRange(iterationDomain[dims->k.back()]))
         options.blockFactors[2] = std::min(*dimK, options.blockFactors[2]);
 
       // Apply more restrictive packing validation.
diff --git a/lib/TPP/Transforms/TransformUtils.cpp b/lib/TPP/Transforms/TransformUtils.cpp
index 22e17d3ae..8e3a75b10 100644
--- a/lib/TPP/Transforms/TransformUtils.cpp
+++ b/lib/TPP/Transforms/TransformUtils.cpp
@@ -281,7 +281,7 @@ isContraction(linalg::LinalgOp linalgOp) {
   return dims;
 }
 
-static std::optional<int64_t> getConstantRange(const Range &range) {
+std::optional<int64_t> getConstantRange(const Range &range) {
   std::optional<int64_t> stride = getConstantIntValue(range.stride);
   if (!stride || *stride != 1)
     return std::nullopt;

From 669872dd579b4b8283253b2000286962f6b3b37e Mon Sep 17 00:00:00 2001
From: Adam Siemieniuk
Date: Thu, 5 Sep 2024 13:03:37 +0200
Subject: [PATCH 3/3] Adjust tests

---
 .../DefaultPipeline/default-tpp-passes.mlir   | 53 ++-----------------
 .../DefaultPipeline/linalg-to-xsmm.mlir       |  7 ++-
 test/Passes/pass-matmul-blocking-default.mlir | 36 +++++++++++++
 test/Passes/pass-matmul-blocking.mlir         | 43 +++++++--------
 test/Passes/tpp-mapping.mlir                  |  5 +-
 5 files changed, 67 insertions(+), 77 deletions(-)

diff --git a/test/Passes/DefaultPipeline/default-tpp-passes.mlir b/test/Passes/DefaultPipeline/default-tpp-passes.mlir
index 0b6590be1..ac1e78228 100644
--- a/test/Passes/DefaultPipeline/default-tpp-passes.mlir
+++ b/test/Passes/DefaultPipeline/default-tpp-passes.mlir
@@ -8,15 +8,15 @@ func.func @matmul(%A: tensor<4x8xf32>, %B: tensor<8x4xf32>, %C: tensor<4x4xf32>) -> tensor<4x4xf32> {
   // CHECK: %[[C0:.+]] = arith.constant 0 : index
   // CHECK: call @xsmm_gemm_dispatch
-  // CHECK: %[[ptr0:.*]] = memref.extract_aligned_pointer_as_index %[[ARG0]]
+  // CHECK: %[[ptr0:.*]] = memref.extract_aligned_pointer_as_index
   // CHECK-NEXT: %[[cast_ptr0:.*]] = arith.index_cast %[[ptr0]] : index to i64
   // CHECK-NEXT: %[[llvm_ptr0:.*]] = llvm.inttoptr %[[cast_ptr0]] : i64 to !llvm.ptr
 
-  // CHECK: %[[ptr1:.*]] = memref.extract_aligned_pointer_as_index %[[ARG1]]
+  // CHECK: %[[ptr1:.*]] = memref.extract_aligned_pointer_as_index
   // CHECK-NEXT: %[[cast_ptr1:.*]] = arith.index_cast %[[ptr1]] : index to i64
   // CHECK-NEXT: %[[llvm_ptr1:.*]] = llvm.inttoptr %[[cast_ptr1]] : i64 to !llvm.ptr
 
-  // CHECK: %[[ptr2:.*]] = memref.extract_aligned_pointer_as_index %[[ARG2]]
+  // CHECK: %[[ptr2:.*]] = memref.extract_aligned_pointer_as_index
   // CHECK-NEXT: %[[cast_ptr2:.*]] = arith.index_cast %[[ptr2]] : index to i64
   // CHECK-NEXT: %[[llvm_ptr2:.*]] = llvm.inttoptr %[[cast_ptr2]] : i64 to !llvm.ptr
@@ -90,53 +90,6 @@ func.func @conv2d_1x1(
 
 // -----
 
-#map = affine_map<(d0, d1) -> (d0 + d1)>
-
-// CHECK-LABEL: @conv2d_1x1_decomposed(
-// CHECK-SAME:  %[[arg:.*]]: memref<1x7x7x2048xf32>) -> memref<1x7x7x512xf32> {
-func.func @conv2d_1x1_decomposed(
-    %arg0 : tensor<1x7x7x2048xf32>) -> tensor<1x7x7x512xf32> {
-  %c0 = arith.constant 0 : index
-  %c1 = arith.constant 1 : index
-  %c7 = arith.constant 7 : index
-
-  // Conv2D weights
-  %cst = arith.constant dense<0.00332225906> : tensor<2048x512xf32>
-
-  // 1x1 Conv2D
-  // CHECK: call @xsmm_gemm_dispatch
-  // CHECK: scf.for
-  // CHECK: %[[ptr0:.*]] = llvm.inttoptr %{{.+}} : i64 to !llvm.ptr
-  // CHECK: %[[ptr1:.*]] = llvm.inttoptr %{{.+}} : i64 to !llvm.ptr
-  // CHECK: %[[ptr2:.*]] = llvm.inttoptr %{{.+}} : i64 to !llvm.ptr
-  // CHECK: call @xsmm_gemm_invoke({{.*}}%[[ptr0]], %{{.+}}, %[[ptr1]], %{{.+}}, %[[ptr2]], %{{.+}}
-  %cst_0 = arith.constant 0.000000e+00 : f32
-  %0 = tensor.empty() : tensor<1x7x7x512xf32>
-  %1 = linalg.fill ins(%cst_0 : f32) outs(%0 : tensor<1x7x7x512xf32>) -> tensor<1x7x7x512xf32>
-  %2 = scf.for %arg1 = %c0 to %c1 step %c1 iter_args(%arg2 = %1) -> (tensor<1x7x7x512xf32>) {
-    %3 = scf.for %arg3 = %c0 to %c7 step %c1 iter_args(%arg4 = %arg2) -> (tensor<1x7x7x512xf32>) {
-      %4 = scf.for %arg5 = %c0 to %c1 step %c1 iter_args(%arg6 = %arg4) -> (tensor<1x7x7x512xf32>) {
-        %5 = scf.for %arg7 = %c0 to %c1 step %c1 iter_args(%arg8 = %arg6) -> (tensor<1x7x7x512xf32>) {
-          %6 = affine.apply #map(%arg3, %arg5)
-          %extracted_slice = tensor.extract_slice %arg0[%arg1, %6, %arg7, 0] [1, 1, 7, 2048] [1, 1, 1, 1] : tensor<1x7x7x2048xf32> to tensor<7x2048xf32>
-          %extracted_slice_1 = tensor.extract_slice %arg8[%arg1, %arg3, 0, 0] [1, 1, 7, 512] [1, 1, 1, 1] : tensor<1x7x7x512xf32> to tensor<7x512xf32>
-          %7 = linalg.matmul ins(%extracted_slice, %cst : tensor<7x2048xf32>, tensor<2048x512xf32>) outs(%extracted_slice_1 : tensor<7x512xf32>) -> tensor<7x512xf32>
-          %inserted_slice = tensor.insert_slice %7 into %arg8[%arg1, %arg3, 0, 0] [1, 1, 7, 512] [1, 1, 1, 1] : tensor<7x512xf32> into tensor<1x7x7x512xf32>
-          scf.yield %inserted_slice : tensor<1x7x7x512xf32>
-        }
-        scf.yield %5 : tensor<1x7x7x512xf32>
-      }
-      scf.yield %4 : tensor<1x7x7x512xf32>
-    }
-    scf.yield %3 : tensor<1x7x7x512xf32>
-  }
-
-  // CHECK: return {{.*}} : memref<1x7x7x512xf32>
-  return %2 : tensor<1x7x7x512xf32>
-}
-
-// -----
-
 #map0 = affine_map<(d0, d1) -> (d1)>
 #map1 = affine_map<(d0, d1) -> (d0, d1)>
 #map2 = affine_map<(d0, d1, d2) -> (d0, d2)>
diff --git a/test/Passes/DefaultPipeline/linalg-to-xsmm.mlir b/test/Passes/DefaultPipeline/linalg-to-xsmm.mlir
index deb7f6789..2b76d1ebf 100644
--- a/test/Passes/DefaultPipeline/linalg-to-xsmm.mlir
+++ b/test/Passes/DefaultPipeline/linalg-to-xsmm.mlir
@@ -50,15 +50,14 @@ func.func @gemm_with_zero(%arg0: tensor<3x3xf32>, %arg1: tensor<3x3xf32>) -> ten
 // CHECK-DAG: %[[C3:.+]] = arith.constant 3 : i64
 // CHECK-DAG: %[[C4:.+]] = arith.constant 4 : i64
 // CHECK-NOT: xsmm_unary_dispatch
-// CHECK: %[[ALLOC:.+]] = memref.alloc() {alignment = 64 : i64} : memref<3x3xf32>
 // CHECK: %[[DIS:.+]] = call @xsmm_gemm_dispatch(%[[C1]], %[[C3]], %[[C3]], %[[C3]], %[[C3]], %[[C3]], %[[C3]], %[[C4]])
 // CHECK: %[[INT_PTR_ARG0:.+]] = memref.extract_aligned_pointer_as_index
 // CHECK: %[[CAST_ARG0:.+]] = arith.index_cast %[[INT_PTR_ARG0]] : index to i64
 // CHECK: %[[LLVM_PTR_ARG0:.+]] = llvm.inttoptr %[[CAST_ARG0]] : i64 to !llvm.ptr
 // CHECK: %[[INT_PTR_ARG1:.+]] = memref.extract_aligned_pointer_as_index
 // CHECK: %[[CAST_ARG1:.+]] = arith.index_cast %[[INT_PTR_ARG1]] : index to i64
 // CHECK: %[[LLVM_PTR_ARG1:.+]] = llvm.inttoptr %[[CAST_ARG1]] : i64 to !llvm.ptr
 // CHECK: %[[INT_PTR_ALLOC:.+]] = memref.extract_aligned_pointer_as_index
 // CHECK: %[[CAST_ALLOC:.+]] = arith.index_cast %[[INT_PTR_ALLOC]] : index to i64
 // CHECK: %[[LLVM_PTR_ALLOC:.+]] = llvm.inttoptr %[[CAST_ALLOC]] : i64 to !llvm.ptr
 // CHECK: call @xsmm_gemm_invoke(%[[C1]], %[[DIS]], %[[LLVM_PTR_ARG0]], %[[C0]], %[[LLVM_PTR_ARG1]], %[[C0]], %[[LLVM_PTR_ALLOC]], %[[C0]])
diff --git a/test/Passes/pass-matmul-blocking-default.mlir b/test/Passes/pass-matmul-blocking-default.mlir
index 61cbdf86c..8e00f1306 100644
--- a/test/Passes/pass-matmul-blocking-default.mlir
+++ b/test/Passes/pass-matmul-blocking-default.mlir
@@ -84,3 +84,39 @@ func.func @block_linalg_matmul_transpose_b(
 // CHECK: %[[VAL:.+]] = linalg.generic {indexing_maps = [#[[MAP3]], #[[MAP4]], #[[MAP5]]], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]} ins(%[[PACK0]], %[[PACK1]] : tensor<4x4x32x32xf32>, tensor<4x4x32x32xf32>) outs(%[[PACK2]] : tensor<4x4x32x32xf32>)
 // CHECK: %[[OUT:.+]] = tensor.unpack %[[VAL]] inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %[[ARG2]] : tensor<4x4x32x32xf32> -> tensor<128x128xf32>
 // CHECK: return %[[OUT]] : tensor<128x128xf32>
+
+// -----
+
+func.func @block_linalg_matmul_dynamic(
+  %arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>)
+    -> tensor<?x?xf32> {
+  %0 = linalg.matmul ins(%arg0, %arg1: tensor<?x?xf32>, tensor<?x?xf32>)
+                     outs(%arg2: tensor<?x?xf32>)
+    -> tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
+}
+
+// CHECK-DAG: #[[MAP3:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d3, d5)>
+// CHECK-DAG: #[[MAP4:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d2, d5, d4)>
+// CHECK-DAG: #[[MAP5:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d4)>
+
+// CHECK-LABEL: func @block_linalg_matmul_dynamic(
+// CHECK-SAME:  %[[ARG0:[0-9a-z]+]]: tensor<?x?xf32>
+// CHECK-SAME:  %[[ARG1:[0-9a-z]+]]: tensor<?x?xf32>
+// CHECK-SAME:  %[[ARG2:[0-9a-z]+]]: tensor<?x?xf32>) -> tensor<?x?xf32> {
+// CHECK-DAG: %[[PAD:.+]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[PACK0:.+]] = tensor.pack %[[ARG0]] padding_value(%[[PAD]] : f32)
+// CHECK-SAME:  outer_dims_perm = [0, 1] inner_dims_pos = [0, 1]
+// CHECK-SAME:  inner_tiles = [32, 32] into {{.*}} : tensor<?x?xf32> -> tensor<?x?x32x32xf32>
+// CHECK: %[[PACK1:.+]] = tensor.pack %[[ARG1]] padding_value(%[[PAD]] : f32)
+// CHECK-SAME:  outer_dims_perm = [1, 0] inner_dims_pos = [0, 1]
+// CHECK-SAME:  inner_tiles = [32, 32] into {{.*}} : tensor<?x?xf32> -> tensor<?x?x32x32xf32>
+// CHECK: %[[PACK2:.+]] = tensor.pack %[[ARG2]] padding_value(%[[PAD]] : f32)
+// CHECK-SAME:  inner_dims_pos = [0, 1] inner_tiles = [32, 32]
+// CHECK-SAME:  into {{.*}} : tensor<?x?xf32> -> tensor<?x?x32x32xf32>
+// CHECK: %[[VAL:.+]] = linalg.generic {indexing_maps = [#[[MAP3]], #[[MAP4]], #[[MAP5]]],
+// CHECK-SAME:  iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]}
+// CHECK-SAME:  ins(%[[PACK0]], %[[PACK1]] : tensor<?x?x32x32xf32>, tensor<?x?x32x32xf32>) outs(%[[PACK2]] : tensor<?x?x32x32xf32>)
+// CHECK: %[[OUT:.+]] = tensor.unpack %[[VAL]] inner_dims_pos = [0, 1] inner_tiles = [32, 32]
+// CHECK-SAME:  into %[[ARG2]] : tensor<?x?x32x32xf32> -> tensor<?x?xf32>
+// CHECK: return %[[OUT]] : tensor<?x?xf32>
diff --git a/test/Passes/pass-matmul-blocking.mlir b/test/Passes/pass-matmul-blocking.mlir
index 3168ad016..830d38916 100644
--- a/test/Passes/pass-matmul-blocking.mlir
+++ b/test/Passes/pass-matmul-blocking.mlir
@@ -60,8 +60,10 @@ func.func @block_dims_equal_to_factors(
 
 // -----
 
-// We don't expect to block as the blocking factor do not create full tiles.
-func.func @block_linalg_matmul(
+// Adapt tile sizes to small dimensions.
+// Assume that there is a separate cost function that controls
+// whether packing should take place at all.
+func.func @block_small_dims_matmul(
   %arg0: tensor<5x6xf32>, %arg1: tensor<6x5xf32>, %arg2: tensor<5x5xf32>)
     -> tensor<5x5xf32> {
   %0 = linalg.matmul ins(%arg0, %arg1: tensor<5x6xf32>, tensor<6x5xf32>)
@@ -68,15 +70,26 @@ func.func @block_linalg_matmul(
                      outs(%arg2: tensor<5x5xf32>)
     -> tensor<5x5xf32>
   return %0 : tensor<5x5xf32>
 }
 
-// CHECK-LABEL: func.func @block_linalg_matmul(
-// CHECK-SAME:  %[[ARG0:[0-9a-z]+]]: tensor<5x6xf32>,
-// CHECK-SAME:  %[[ARG1:[0-9a-z]+]]: tensor<6x5xf32>,
-// CHECK-SAME:  %[[ARG2:[0-9a-z]+]]: tensor<5x5xf32>) -> tensor<5x5xf32> {
-// CHECK: %{{.+}} = linalg.matmul
-// CHECK-SAME:  ins(%[[ARG0]], %[[ARG1]]
-// CHECK-SAME:  outs(%[[ARG2]]
+// CHECK-DAG: #[[MAP3:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d3, d5)>
+// CHECK-DAG: #[[MAP4:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d2, d5, d4)>
+// CHECK-DAG: #[[MAP5:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d4)>
+
+// CHECK-LABEL: func @block_small_dims_matmul(
+// CHECK-SAME:  %[[ARG0:[0-9a-z]+]]: tensor<5x6xf32>
+// CHECK-SAME:  %[[ARG1:[0-9a-z]+]]: tensor<6x5xf32>
+// CHECK-SAME:  %[[ARG2:[0-9a-z]+]]: tensor<5x5xf32>) -> tensor<5x5xf32> {
+// CHECK: %[[BUF0:.+]] = tensor.empty() : tensor<1x1x5x6xf32>
+// CHECK: %[[PACK0:.+]] = tensor.pack %[[ARG0]] outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [5, 6] into %[[BUF0]] : tensor<5x6xf32> -> tensor<1x1x5x6xf32>
+// CHECK: %[[BUF1:.*]] = tensor.empty() : tensor<1x1x6x5xf32>
+// CHECK: %[[PACK1:.+]] = tensor.pack %[[ARG1]] outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [6, 5] into %[[BUF1]] : tensor<6x5xf32> -> tensor<1x1x6x5xf32>
+// CHECK: %[[BUF2:.+]] = tensor.empty() : tensor<1x1x5x5xf32>
+// CHECK: %[[PACK2:.+]] = tensor.pack %[[ARG2]] inner_dims_pos = [0, 1] inner_tiles = [5, 5] into %[[BUF2]] : tensor<5x5xf32> -> tensor<1x1x5x5xf32>
+// CHECK: %[[VAL:.+]] = linalg.generic {indexing_maps = [#[[MAP3]], #[[MAP4]], #[[MAP5]]], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]} ins(%[[PACK0]], %[[PACK1]] : tensor<1x1x5x6xf32>, tensor<1x1x6x5xf32>) outs(%[[PACK2]] : tensor<1x1x5x5xf32>)
+// CHECK: %[[OUT:.+]] = tensor.unpack %[[VAL]] inner_dims_pos = [0, 1] inner_tiles = [5, 5] into %[[ARG2]] : tensor<1x1x5x5xf32> -> tensor<5x5xf32>
+// CHECK: return %[[OUT]] : tensor<5x5xf32>
+// CHECK: }
 
 // -----
 
@@ -183,15 +196,3 @@ func.func @batch_matmul_rewrite(%arg0: tensor<512x64x128xf32>, %arg1: tensor<512
 // CHECK: %[[UNPACK:.+]] = tensor.unpack %[[GEN]]
 // CHECK-SAME: inner_dims_pos = [1, 2] inner_tiles = [32, 32]
 // CHECK-SAME: into %[[OUT]] : tensor<512x2x2x32x32xf32> -> tensor<512x64x64xf32>
-
-// -----
-
-// CHECK-LABEL: batch_matmul_invalid_tiles
-func.func @batch_matmul_invalid_tiles(%arg0: tensor<5x5x5xf32>, %arg1: tensor<5x5x5xf32>) -> tensor<5x5x5xf32> {
-  %0 = tensor.empty() : tensor<5x5x5xf32>
-  // CHECK: linalg.batch_matmul
-  // CHECK-NOT: linalg.generic
-  %1 = linalg.batch_matmul ins(%arg0, %arg1 : tensor<5x5x5xf32>, tensor<5x5x5xf32>)
-                           outs(%0 : tensor<5x5x5xf32>) -> tensor<5x5x5xf32>
-  return %1 : tensor<5x5x5xf32>
-}
diff --git a/test/Passes/tpp-mapping.mlir b/test/Passes/tpp-mapping.mlir
index e50c297bc..29db80d2a 100644
--- a/test/Passes/tpp-mapping.mlir
+++ b/test/Passes/tpp-mapping.mlir
@@ -15,9 +15,10 @@ func.func @conv_to_matmul(%img: tensor<1x5x5x3xf32>, %filter: tensor<3x3x3x8xf32
 // CHECK: scf.for
 // CHECK: tensor.extract_slice{{[^:]+}}: tensor<1x5x5x3xf32> to tensor<3x3xf32>
 // CHECK: tensor.extract_slice{{[^:]+}}: tensor<3x3x3x8xf32> to tensor<3x8xf32>
-// CHECK: tensor.extract_slice{{[^:]+}}: tensor<1x3x3x8xf32> to tensor<3x8xf32>
+// CHECK: tensor.extract_slice{{[^:]+}}: tensor<1x1x3x8xf32> to tensor<3x8xf32>
 // CHECK: linalg.matmul{{.*}} -> tensor<3x8xf32>
-// CHECK: tensor.insert_slice{{[^:]+}}: tensor<3x8xf32> into tensor<1x3x3x8xf32>
+// CHECK: tensor.insert_slice{{[^:]+}}: tensor<3x8xf32> into tensor<1x1x3x8xf32>
+// CHECK: tensor.insert_slice{{[^:]+}}: tensor<1x1x3x8xf32> into tensor<1x3x3x8xf32>
 // CHECK: }
 
 // -----
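
---

Reviewer note, not part of the patches above: the heart of patch 1 is the rule that clamps the default 32x32x32 blocking factors to statically known iteration-domain spans. The standalone C++ sketch below mirrors that rule under stated assumptions. SimpleRange and constantSpan are illustrative stand-ins for MLIR's Range and the getConstantRange helper; the real pass derives the iteration domain through TilingInterface and linalg::inferContractionDims as shown in the diff. The hard-coded 5x6 * 6x5 matmul matches the new block_small_dims_matmul test.

// clamp_demo.cpp -- illustrative only; mirrors the blockFactors clamping
// added to PackMatmul in patch 1.
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <optional>

// Stand-in for an MLIR Range whose offset/size/stride may or may not be
// compile-time constants.
struct SimpleRange {
  std::optional<int64_t> offset, size, stride;
};

// Mirrors getConstantRange: the span is only known statically for a
// unit-stride range with constant offset and size.
static std::optional<int64_t> constantSpan(const SimpleRange &r) {
  if (!r.stride || *r.stride != 1)
    return std::nullopt;
  if (!r.offset || !r.size)
    return std::nullopt;
  return *r.size - *r.offset;
}

int main() {
  // Default [M, N, K] blocking factors, as in the pass.
  int64_t blockFactors[3] = {32, 32, 32};

  // Iteration domain of the 5x6 * 6x5 matmul from the
  // block_small_dims_matmul test: M = 5, N = 5, K = 6.
  SimpleRange domain[3] = {{0, 5, 1}, {0, 5, 1}, {0, 6, 1}};

  // Clamp each factor to the dimension span when it is statically known;
  // dynamic dimensions keep the default factor and rely on padding.
  for (int i = 0; i < 3; ++i)
    if (std::optional<int64_t> span = constantSpan(domain[i]))
      blockFactors[i] = std::min(*span, blockFactors[i]);

  // Prints "5 5 6": the small matmul is packed with full-dimension tiles
  // (1x1 outer tiles) instead of being rejected outright, which is exactly
  // what the adjusted block_small_dims_matmul CHECK lines verify.
  std::cout << blockFactors[0] << ' ' << blockFactors[1] << ' '
            << blockFactors[2] << '\n';
}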