From 89dfc1ead49f20342afc8ad3a01d4c134df7c9e4 Mon Sep 17 00:00:00 2001
From: Adam Siemieniuk
Date: Tue, 23 Jul 2024 19:17:23 +0200
Subject: [PATCH] Generalize before packing propagation

Shifts the tpp-mapping named-op generalization step earlier so that
packing propagation also works when the input IR contains named Linalg
ops. Currently, the upstream propagation passes only operate on
linalg.generic ops.
---
 lib/TPP/PassBundles/TppMapping.cpp |  4 +++-
 test/Passes/tpp-mapping.mlir       | 22 ++++++++++++++++++++++
 2 files changed, 25 insertions(+), 1 deletion(-)

diff --git a/lib/TPP/PassBundles/TppMapping.cpp b/lib/TPP/PassBundles/TppMapping.cpp
index d39b62530..66238403a 100644
--- a/lib/TPP/PassBundles/TppMapping.cpp
+++ b/lib/TPP/PassBundles/TppMapping.cpp
@@ -67,11 +67,13 @@ struct TppMapping : public tpp::impl::TppMappingBase<TppMapping>,
     // Run only canonicalizer at this stage as full cleanup (mostly CSE) can
     // mess up tensor producer-consumer chains used for analysis in the
     // following passes.
+    // Generalize named ops to allow packing propagation.
+    // TODO: Remove the generalization when upstream propagation is improved.
+    pm.addNestedPass<func::FuncOp>(createLinalgGeneralizeNamedOpsPass());
     pm.addPass(createPropagatePackUnPack());
     pm.addPass(createConstantFoldPack());
     pm.addPass(createSimplifyAndCanonicalizePack());
-    pm.addNestedPass<func::FuncOp>(createLinalgGeneralizeNamedOpsPass());
     pm.addPass(createCleanup());
     pm.addNestedPass<func::FuncOp>(
         createLinalgConvertCompareSelectToMaximumfPass());
diff --git a/test/Passes/tpp-mapping.mlir b/test/Passes/tpp-mapping.mlir
index e50c297bc..ed2f19aa3 100644
--- a/test/Passes/tpp-mapping.mlir
+++ b/test/Passes/tpp-mapping.mlir
@@ -207,3 +207,25 @@ func.func @tile_and_fuse(%arg0: tensor<64x64xf32>, %arg1: tensor<64x64xf32>,
 // CHECK-SAME:{{.*}}outs(%{{.+}} : tensor<32x32xf32>)
 // CHECK: linalg.generic{{.*}}outs(%{{.+}} : tensor<32x32xf32>)
 // CHECK: arith.maximumf
+
+// -----
+
+func.func @tile_and_fuse_named(%arg0: tensor<64x64xf32>, %arg1: tensor<64x64xf32>,
+                               %arg2: tensor<64x64xf32>, %arg3: tensor<64x64xf32>) -> tensor<64x64xf32> {
+  %e = tensor.empty() : tensor<64x64xf32>
+  %0 = linalg.matmul ins(%arg0, %arg1 : tensor<64x64xf32>, tensor<64x64xf32>)
+                     outs(%arg2 : tensor<64x64xf32>) -> tensor<64x64xf32>
+  %1 = linalg.add ins(%0, %arg3 : tensor<64x64xf32>, tensor<64x64xf32>)
+                  outs(%e : tensor<64x64xf32>) -> tensor<64x64xf32>
+  return %1 : tensor<64x64xf32>
+}
+
+// CHECK-LABEL: tile_and_fuse_named(
+// CHECK-COUNT-3: tensor.pack
+// Fused matmul and add
+// CHECK: scf.forall
+// CHECK: linalg.batch_reduce_matmul{{.*}}ins(%{{.+}}, %{{.+}} : tensor<2x32x32xf32>, tensor<2x32x32xf32>)
+// CHECK-SAME:{{.*}}outs(%{{.+}} : tensor<32x32xf32>)
+// CHECK: linalg.generic{{.*}}outs(%{{.+}} : tensor<32x32xf32>)
+// CHECK: arith.addf
+// CHECK: tensor.unpack
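
Note: the generalization step moved above rewrites named Linalg ops into
linalg.generic form before pack/unpack propagation runs. As an
illustrative sketch (hand-written here, not produced by this patch; the
actual rewrite is performed by createLinalgGeneralizeNamedOpsPass), the
matmul from the new test would generalize roughly as follows:

  // Named form: not matched by the upstream propagation patterns.
  %0 = linalg.matmul ins(%arg0, %arg1 : tensor<64x64xf32>, tensor<64x64xf32>)
                     outs(%arg2 : tensor<64x64xf32>) -> tensor<64x64xf32>

  // Generalized form: explicit indexing maps and iterator types that
  // PropagatePackUnPack can analyze and rewrite.
  #map_a = affine_map<(d0, d1, d2) -> (d0, d2)>
  #map_b = affine_map<(d0, d1, d2) -> (d2, d1)>
  #map_c = affine_map<(d0, d1, d2) -> (d0, d1)>
  %0 = linalg.generic {indexing_maps = [#map_a, #map_b, #map_c],
                       iterator_types = ["parallel", "parallel", "reduction"]}
      ins(%arg0, %arg1 : tensor<64x64xf32>, tensor<64x64xf32>)
      outs(%arg2 : tensor<64x64xf32>) {
    ^bb0(%in: f32, %in_0: f32, %out: f32):
      %mul = arith.mulf %in, %in_0 : f32
      %acc = arith.addf %mul, %out : f32
      linalg.yield %acc : f32
  } -> tensor<64x64xf32>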