Preserve lowering config attribute during rematerialization. (iree-or…
MaheshRavishankar authored Oct 5, 2023
1 parent d76a104 commit 24d80e1
Showing 2 changed files with 6 additions and 12 deletions.

@@ -50,18 +50,11 @@ struct RematerializeParallelOpsPattern
 FailureOr<linalg::ElementwiseOpFusionResult> fusionResult =
     linalg::fuseElementwiseOps(rewriter, &opOperand);
 if (succeeded(fusionResult)) {
-  // Copy over lowering_config if after fusion we still see the same loop
-  // count to enable using this pass inside a CodeGen pipeline.
-  // TODO: This is hacky and it pretty much assumes all parallel producer
-  // ops which does not change loop structure at all.
-  if (auto linalgOp = dyn_cast<linalg::LinalgOp>(fusionResult->fusedOp)) {
-    if (genericOp.getNumLoops() == linalgOp.getNumLoops())
-      if (Attribute attr = genericOp->getAttr("lowering_config"))
-        linalgOp->setAttr("lowering_config", attr);
-  }
-
   auto replacements = fusionResult->fusedOp->getResults().take_back(
       genericOp.getNumResults());
+  // Copy over any non native attributes for the operation.
+  auto prunedAttributeList = linalg::getPrunedAttributeList(genericOp);
+  fusionResult->fusedOp->setAttrs(prunedAttributeList);
   rewriter.replaceOp(genericOp, replacements);
   return success();
 }
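
For context on the new code path: linalg::getPrunedAttributeList returns an op's attributes with the Linalg-native ones (indexing maps, iterator types, and so on) filtered out, so discretionary attributes on the producer, lowering_config included, are carried onto the fused op wholesale instead of being special-cased per attribute. Below is a minimal sketch of that idea, assuming upstream MLIR's usual headers and signatures; the helper name copyDiscretionaryAttrs is hypothetical and not part of this commit.

#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h" // assumed home of getPrunedAttributeList
#include "mlir/IR/Operation.h"

using namespace mlir;

// Hypothetical helper illustrating the commit's approach: copy every
// attribute that is not native to the linalg op (getPrunedAttributeList
// filters out indexing_maps, iterator_types, and similar built-ins) from
// the producer onto the op created by elementwise fusion.
static void copyDiscretionaryAttrs(linalg::GenericOp producer,
                                   Operation *fusedOp) {
  SmallVector<NamedAttribute> pruned =
      linalg::getPrunedAttributeList(producer);
  fusedOp->setAttrs(pruned); // lowering_config (and e.g. "foo") survive here
}

The second changed file is the pass's lit test, updated to check that a discretionary attribute (attrs = {foo = "foo"}) survives rematerialization:
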
@@ -1,4 +1,4 @@
-// RUN: iree-opt -iree-codegen-rematerialize-parallel-ops %s | FileCheck %s
+// RUN: iree-opt -iree-codegen-rematerialize-parallel-ops --split-input-file %s | FileCheck %s
 
 func.func @merged_reduction_parallel(%0: tensor<1x40960xf32>, %1: tensor<1xf32>, %7: tensor<1xf32>)
     -> tensor<1x40960xf32> {
@@ -25,7 +25,7 @@ func.func @merged_reduction_parallel(%0: tensor<1x40960xf32>, %1: tensor<1xf32>,
                      affine_map<(d0, d1) -> (d0, d1)>],
      iterator_types = ["parallel", "parallel"]}
      ins(%8, %7 : tensor<1x40960xf32>, tensor<1xf32>)
-     outs(%2 : tensor<1x40960xf32>) {
+     outs(%2 : tensor<1x40960xf32>) attrs = {foo = "foo"} {
   ^bb0(%in: f32, %in_2: f32, %out: f32):
     %10 = arith.divf %cst, %in_2 : f32
     %11 = arith.mulf %in, %10 : f32
@@ -39,6 +39,7 @@ func.func @merged_reduction_parallel(%0: tensor<1x40960xf32>, %1: tensor<1xf32>,
 
 // CHECK-LABEL: func.func @merged_reduction_parallel
 // CHECK: %{{.+}} = linalg.generic
+// CHECK-SAME: attrs = {foo = "foo"}
 // CHECK: arith.subf
 // CHECK-NEXT: math.exp
 // CHECK-NEXT: arith.divf
