Generalize before packing propagation
Shifts the tpp-mapping named-op generalization step earlier so that packing
propagation can apply when the input IR contains named Linalg ops.
Currently, the upstream propagation passes only work on linalg.generic ops.
adam-smnk committed Jul 23, 2024
1 parent ea51a74 commit 89dfc1e
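For context, generalization rewrites each named Linalg op into an equivalent linalg.generic. Below is a minimal sketch of that rewrite for a matmul; the wrapper function and operand names are illustrative, not part of this commit:

func.func @generalized_matmul(%A: tensor<64x64xf32>, %B: tensor<64x64xf32>,
                              %C: tensor<64x64xf32>) -> tensor<64x64xf32> {
  // Named form, not matched by the upstream propagation patterns:
  //   linalg.matmul ins(%A, %B : ...) outs(%C : ...)
  // Generalized form produced by linalg-generalize-named-ops: explicit
  // indexing maps and a scalar body turn it into the linalg.generic that
  // the pack/unpack propagation patterns can match.
  %0 = linalg.generic
      {indexing_maps = [affine_map<(m, n, k) -> (m, k)>,
                        affine_map<(m, n, k) -> (k, n)>,
                        affine_map<(m, n, k) -> (m, n)>],
       iterator_types = ["parallel", "parallel", "reduction"]}
      ins(%A, %B : tensor<64x64xf32>, tensor<64x64xf32>)
      outs(%C : tensor<64x64xf32>) {
    ^bb0(%a: f32, %b: f32, %acc: f32):
      %mul = arith.mulf %a, %b : f32
      %sum = arith.addf %acc, %mul : f32
      linalg.yield %sum : f32
  } -> tensor<64x64xf32>
  return %0 : tensor<64x64xf32>
}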
Showing 2 changed files with 25 additions and 1 deletion.
lib/TPP/PassBundles/TppMapping.cpp: 3 additions & 1 deletion
@@ -67,11 +67,13 @@ struct TppMapping : public tpp::impl::TppMappingBase<TppMapping>,
     // Run only canonicalizer at this stage as full cleanup (mostly CSE) can
     // mess up tensor producer-consumer chains used for analysis in the
     // following passes.
+    // Generalize named ops to allow packing propagation.
+    // TODO: Remove the generalization when upstream propagation is improved.
+    pm.addNestedPass<func::FuncOp>(createLinalgGeneralizeNamedOpsPass());
     pm.addPass(createPropagatePackUnPack());
     pm.addPass(createConstantFoldPack());
     pm.addPass(createSimplifyAndCanonicalizePack());
 
-    pm.addNestedPass<func::FuncOp>(createLinalgGeneralizeNamedOpsPass());
     pm.addPass(createCleanup());
     pm.addNestedPass<func::FuncOp>(
         createLinalgConvertCompareSelectToMaximumfPass());
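The ordering matters because the upstream patterns invoked by createPropagatePackUnPack bubble tensor.pack up and sink tensor.unpack down through linalg.generic ops only; named ops block them, hence generalizing first. A hedged sketch of the kind of rewrite involved follows (shapes, names, and the elementwise body are illustrative):

// Before propagation: the elementwise generic consumes the unpacked tensor.
func.func @before(%packed: tensor<2x2x32x32xf32>) -> tensor<64x64xf32> {
  %d = tensor.empty() : tensor<64x64xf32>
  %u = tensor.unpack %packed inner_dims_pos = [0, 1] inner_tiles = [32, 32]
      into %d : tensor<2x2x32x32xf32> -> tensor<64x64xf32>
  %init = tensor.empty() : tensor<64x64xf32>
  %r = linalg.generic
      {indexing_maps = [affine_map<(i, j) -> (i, j)>,
                        affine_map<(i, j) -> (i, j)>],
       iterator_types = ["parallel", "parallel"]}
      ins(%u : tensor<64x64xf32>) outs(%init : tensor<64x64xf32>) {
    ^bb0(%in: f32, %out: f32):
      %twice = arith.addf %in, %in : f32
      linalg.yield %twice : f32
  } -> tensor<64x64xf32>
  return %r : tensor<64x64xf32>
}

// After propagation: the generic runs directly on the packed layout and the
// unpack sinks below it, keeping the producer-consumer chain packed.
func.func @after(%packed: tensor<2x2x32x32xf32>) -> tensor<64x64xf32> {
  %init = tensor.empty() : tensor<2x2x32x32xf32>
  %r = linalg.generic
      {indexing_maps = [affine_map<(a, b, c, d) -> (a, b, c, d)>,
                        affine_map<(a, b, c, d) -> (a, b, c, d)>],
       iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
      ins(%packed : tensor<2x2x32x32xf32>) outs(%init : tensor<2x2x32x32xf32>) {
    ^bb0(%in: f32, %out: f32):
      %twice = arith.addf %in, %in : f32
      linalg.yield %twice : f32
  } -> tensor<2x2x32x32xf32>
  %d = tensor.empty() : tensor<64x64xf32>
  %u = tensor.unpack %r inner_dims_pos = [0, 1] inner_tiles = [32, 32]
      into %d : tensor<2x2x32x32xf32> -> tensor<64x64xf32>
  return %u : tensor<64x64xf32>
}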
test/Passes/tpp-mapping.mlir: 22 additions & 0 deletions
@@ -207,3 +207,25 @@ func.func @tile_and_fuse(%arg0: tensor<64x64xf32>, %arg1: tensor<64x64xf32>,
 // CHECK-SAME:{{.*}}outs(%{{.+}} : tensor<32x32xf32>)
 // CHECK: linalg.generic{{.*}}outs(%{{.+}} : tensor<32x32xf32>)
 // CHECK: arith.maximumf
+
+// -----
+
+func.func @tile_and_fuse_named(%arg0: tensor<64x64xf32>, %arg1: tensor<64x64xf32>,
+    %arg2: tensor<64x64xf32>, %arg3: tensor<64x64xf32>) -> tensor<64x64xf32> {
+  %e = tensor.empty() : tensor<64x64xf32>
+  %0 = linalg.matmul ins(%arg0, %arg1 : tensor<64x64xf32>, tensor<64x64xf32>)
+                     outs(%arg2 : tensor<64x64xf32>) -> tensor<64x64xf32>
+  %1 = linalg.add ins(%0, %arg3 : tensor<64x64xf32>, tensor<64x64xf32>)
+                  outs(%e : tensor<64x64xf32>) -> tensor<64x64xf32>
+  return %1 : tensor<64x64xf32>
+}
+
+// CHECK-LABEL: tile_and_fuse_named(
+// CHECK-COUNT-3: tensor.pack
+// Fused matmul and add
+// CHECK: scf.forall
+// CHECK: linalg.batch_reduce_matmul{{.*}}ins(%{{.+}}, %{{.+}} : tensor<2x32x32xf32>, tensor<2x32x32xf32>)
+// CHECK-SAME:{{.*}}outs(%{{.+}} : tensor<32x32xf32>)
+// CHECK: linalg.generic{{.*}}outs(%{{.+}} : tensor<32x32xf32>)
+// CHECK: arith.addf
+// CHECK: tensor.unpack
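For readers new to these ops, here is a standalone illustration of the layout the checks above encode: a 64x64 operand is blocked into a 2x2 grid of 32x32 tiles, which is where the tensor<2x32x32xf32> slices fed to linalg.batch_reduce_matmul come from. The function and names are hypothetical:

func.func @pack_example(%src: tensor<64x64xf32>) -> tensor<2x2x32x32xf32> {
  %dest = tensor.empty() : tensor<2x2x32x32xf32>
  // Tile both dims by 32: element (i, j) of %src lands at
  // (i floordiv 32, j floordiv 32, i mod 32, j mod 32) of the result.
  %packed = tensor.pack %src inner_dims_pos = [0, 1] inner_tiles = [32, 32]
      into %dest : tensor<64x64xf32> -> tensor<2x2x32x32xf32>
  return %packed : tensor<2x2x32x32xf32>
}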
