From 24d80e165b816dfde21d32c31f8554fedba20647 Mon Sep 17 00:00:00 2001
From: MaheshRavishankar <1663364+MaheshRavishankar@users.noreply.github.com>
Date: Wed, 4 Oct 2023 21:02:23 -0700
Subject: [PATCH] Preserve lowering config attribute during rematerialization.
 (#15103)

---
 .../Codegen/Common/RematerializeParallelOps.cpp  | 13 +++----------
 .../Common/test/rematerialize_parallel_ops.mlir  |  5 +++--
 2 files changed, 6 insertions(+), 12 deletions(-)

diff --git a/compiler/src/iree/compiler/Codegen/Common/RematerializeParallelOps.cpp b/compiler/src/iree/compiler/Codegen/Common/RematerializeParallelOps.cpp
index 3676a7c27e94..bdd5b59e5126 100644
--- a/compiler/src/iree/compiler/Codegen/Common/RematerializeParallelOps.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/RematerializeParallelOps.cpp
@@ -50,18 +50,11 @@ struct RematerializeParallelOpsPattern
       FailureOr<linalg::ElementwiseOpFusionResult> fusionResult =
           linalg::fuseElementwiseOps(rewriter, &opOperand);
       if (succeeded(fusionResult)) {
-        // Copy over lowering_config if after fusion we still see the same loop
-        // count to enable using this pass inside a CodeGen pipeline.
-        // TODO: This is hacky and it pretty much assumes all parallel producer
-        // ops which does not change loop structure at all.
-        if (auto linalgOp = dyn_cast<linalg::LinalgOp>(fusionResult->fusedOp)) {
-          if (genericOp.getNumLoops() == linalgOp.getNumLoops())
-            if (Attribute attr = genericOp->getAttr("lowering_config"))
-              linalgOp->setAttr("lowering_config", attr);
-        }
-
         auto replacements = fusionResult->fusedOp->getResults().take_back(
             genericOp.getNumResults());
+        // Copy over any non native attributes for the operation.
+        auto prunedAttributeList = linalg::getPrunedAttributeList(genericOp);
+        fusionResult->fusedOp->setAttrs(prunedAttributeList);
         rewriter.replaceOp(genericOp, replacements);
         return success();
       }
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/rematerialize_parallel_ops.mlir b/compiler/src/iree/compiler/Codegen/Common/test/rematerialize_parallel_ops.mlir
index 0e60a8e35f11..2cb766a45e1e 100644
--- a/compiler/src/iree/compiler/Codegen/Common/test/rematerialize_parallel_ops.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/test/rematerialize_parallel_ops.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-opt -iree-codegen-rematerialize-parallel-ops %s | FileCheck %s
+// RUN: iree-opt -iree-codegen-rematerialize-parallel-ops --split-input-file %s | FileCheck %s
 
 func.func @merged_reduction_parallel(%0: tensor<1x40960xf32>, %1: tensor<1xf32>,
     %7: tensor<1xf32>) -> tensor<1x40960xf32> {
@@ -25,7 +25,7 @@ func.func @merged_reduction_parallel(%0: tensor<1x40960xf32>, %1: tensor<1xf32>,
                      affine_map<(d0, d1) -> (d0, d1)>],
     iterator_types = ["parallel", "parallel"]}
     ins(%8, %7 : tensor<1x40960xf32>, tensor<1xf32>)
-    outs(%2 : tensor<1x40960xf32>) {
+    outs(%2 : tensor<1x40960xf32>) attrs = {foo = "foo"} {
   ^bb0(%in: f32, %in_2: f32, %out: f32):
     %10 = arith.divf %cst, %in_2 : f32
     %11 = arith.mulf %in, %10 : f32
@@ -39,6 +39,7 @@ func.func @merged_reduction_parallel(%0: tensor<1x40960xf32>, %1: tensor<1xf32>,
 
 // CHECK-LABEL: func.func @merged_reduction_parallel
 //       CHECK:   %{{.+}} = linalg.generic
+//  CHECK-SAME:     attrs = {foo = "foo"}
 //       CHECK:   arith.subf
 //  CHECK-NEXT:   math.exp
 //  CHECK-NEXT:   arith.divf
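
Note on the replacement API: linalg::getPrunedAttributeList(genericOp) returns roughly the op's attribute list minus the attributes that are inherent to the linalg op itself (indexing maps, iterator types, operand segment sizes, and the like), so discretionary attributes such as lowering_config, or the attrs = {foo = "foo"} added in the test above, survive fusion. Below is a minimal standalone sketch of that filtering idea only; it uses plain C++ containers and a made-up kInherentNames list rather than the real MLIR types or registered attribute names, so treat it as an illustration of the intent, not of the MLIR implementation.

#include <iostream>
#include <map>
#include <set>
#include <string>

// Illustrative stand-in for MLIR attribute storage; not a real MLIR type.
using AttrMap = std::map<std::string, std::string>;

// Hypothetical list modeling an op's inherent ("native") attributes, which
// fusion recomputes and therefore must not be blindly copied over.
static const std::set<std::string> kInherentNames = {
    "indexing_maps", "iterator_types", "operandSegmentSizes"};

// Keep only the discretionary attributes of the original op, mirroring the
// intent of linalg::getPrunedAttributeList in the patch above.
AttrMap getPrunedAttributeList(const AttrMap &originalAttrs) {
  AttrMap pruned;
  for (const auto &[name, value] : originalAttrs)
    if (!kInherentNames.count(name))
      pruned[name] = value;
  return pruned;
}

int main() {
  AttrMap genericOpAttrs = {{"indexing_maps", "[...]"},
                            {"iterator_types", "[parallel, parallel]"},
                            {"lowering_config", "#config"},
                            {"foo", "\"foo\""}};
  // The fused op receives only the non-inherent attributes.
  AttrMap fusedOpAttrs = getPrunedAttributeList(genericOpAttrs);
  for (const auto &[name, value] : fusedOpAttrs)
    std::cout << name << " = " << value << "\n";  // prints foo, lowering_config
  return 0;
}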