diff --git a/lib/TPP/PassBundles/TppMapping.cpp b/lib/TPP/PassBundles/TppMapping.cpp index d39b62530..66238403a 100644 --- a/lib/TPP/PassBundles/TppMapping.cpp +++ b/lib/TPP/PassBundles/TppMapping.cpp @@ -67,11 +67,13 @@ struct TppMapping : public tpp::impl::TppMappingBase, // Run only canonicalizer at this stage as full cleanup (mostly CSE) can // mess up tensor producer-consumer chains used for analysis in the // following passes. + // Generalize named ops to allow packing propagation. + // TODO: Remove the generalization when upstream propagation is improved. + pm.addNestedPass(createLinalgGeneralizeNamedOpsPass()); pm.addPass(createPropagatePackUnPack()); pm.addPass(createConstantFoldPack()); pm.addPass(createSimplifyAndCanonicalizePack()); - pm.addNestedPass(createLinalgGeneralizeNamedOpsPass()); pm.addPass(createCleanup()); pm.addNestedPass( createLinalgConvertCompareSelectToMaximumfPass()); diff --git a/test/Passes/tpp-mapping.mlir b/test/Passes/tpp-mapping.mlir index e50c297bc..ed2f19aa3 100644 --- a/test/Passes/tpp-mapping.mlir +++ b/test/Passes/tpp-mapping.mlir @@ -207,3 +207,25 @@ func.func @tile_and_fuse(%arg0: tensor<64x64xf32>, %arg1: tensor<64x64xf32>, // CHECK-SAME:{{.*}}outs(%{{.+}} : tensor<32x32xf32>) // CHECK: linalg.generic{{.*}}outs(%{{.+}} : tensor<32x32xf32>) // CHECK: arith.maximumf + +// ----- + +func.func @tile_and_fuse_named(%arg0: tensor<64x64xf32>, %arg1: tensor<64x64xf32>, + %arg2: tensor<64x64xf32>, %arg3: tensor<64x64xf32>) -> tensor<64x64xf32> { + %e = tensor.empty() : tensor<64x64xf32> + %0 = linalg.matmul ins(%arg0, %arg1 : tensor<64x64xf32>, tensor<64x64xf32>) + outs(%arg2 : tensor<64x64xf32>) -> tensor<64x64xf32> + %1 = linalg.add ins(%0, %arg3 : tensor<64x64xf32>, tensor<64x64xf32>) + outs(%e : tensor<64x64xf32>) -> tensor<64x64xf32> + return %1 : tensor<64x64xf32> +} + +// CHECK-LABEL: tile_and_fuse_named( +// CHECK-COUNT-3: tensor.pack +// Fused matmul and add +// CHECK: scf.forall +// CHECK: 
linalg.batch_reduce_matmul{{.*}}ins(%{{.+}}, %{{.+}} : tensor<2x32x32xf32>, tensor<2x32x32xf32>) +// CHECK-SAME:{{.*}}outs(%{{.+}} : tensor<32x32xf32>) +// CHECK: linalg.generic{{.*}}outs(%{{.+}} : tensor<32x32xf32>) +// CHECK: arith.addf +// CHECK: tensor.unpack