Skip to content

Commit

Permalink
Add bounds to the affine.if regions (#380)
Browse files Browse the repository at this point in the history
Add bounds to the affine.if regions
  • Loading branch information
josel-amd authored Oct 2, 2024
1 parent 3692340 commit 09ddec3
Show file tree
Hide file tree
Showing 3 changed files with 97 additions and 1 deletion.
3 changes: 2 additions & 1 deletion mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -336,7 +336,8 @@ def AffineForOp : Affine_Op<"for",
def AffineIfOp : Affine_Op<"if",
[ImplicitAffineTerminator, RecursivelySpeculatable,
RecursiveMemoryEffects, NoRegionArguments,
DeclareOpInterfaceMethods<RegionBranchOpInterface>
DeclareOpInterfaceMethods<RegionBranchOpInterface,
["getRegionInvocationBounds"]>
]> {
let summary = "if-then-else operation";
let description = [{
Expand Down
7 changes: 7 additions & 0 deletions mlir/lib/Dialect/Affine/IR/AffineOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2765,6 +2765,13 @@ struct AlwaysTrueOrFalseIf : public OpRewritePattern<AffineIfOp> {
};
} // namespace

/// Reports, for every region held by this op, how many times that region may
/// be invoked. `operands` is not consulted here: each region is unconditionally
/// given the bounds [0, 1], i.e. it may be skipped or entered once.
void AffineIfOp::getRegionInvocationBounds(
    ArrayRef<Attribute> operands,
    SmallVectorImpl<InvocationBounds> &invocationBounds) {
  // Lower bound 0, upper bound 1 -- shared by all regions of the op.
  const InvocationBounds zeroOrOnce(/*lb=*/0, /*ub=*/1);
  invocationBounds.assign(getNumRegions(), zeroOrOnce);
}

/// AffineIfOp has two regions -- `then` and `else`. The flow of data should be
/// as follows: AffineIfOp -> `then`/`else` -> AffineIfOp
void AffineIfOp::getSuccessorRegions(
Expand Down
88 changes: 88 additions & 0 deletions mlir/test/Dialect/Affine/control-flow-sink.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
// RUN: mlir-opt -split-input-file -control-flow-sink %s | FileCheck %s

// Sinking of values defined in the region that encloses the affine.if:
// %add and %const0 are computed in the linalg.generic body ahead of the
// affine.if, and each one is used by exactly one branch. The FileCheck
// expectations that follow this function require the addf to be printed
// inside the `then` branch and the constant inside the `else` branch once
// -control-flow-sink has run.

// Integer set tested by the affine.if; it holds when d0 <= 3.
#set = affine_set<(d0) : (-d0 + 3 >= 0)>
// Identity map used for both the input and the output of the linalg.generic.
#map = affine_map<(d0) -> (d0)>

func.func @test_affine_if_sink(%arg1: tensor<4xf32>) -> tensor<4xf32> {
%0 = tensor.empty() : tensor<4xf32>
%1 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]}
ins(%arg1: tensor<4xf32>) outs(%0: tensor<4xf32>) {
^bb0(%in: f32, %out: f32):
%index = linalg.index 0 : index
// Both values below are defined before the affine.if but are only needed by
// a single branch each.
%const0 = arith.constant 0.0 : f32
%add = arith.addf %in, %in: f32
%4 = affine.if #set(%index) -> f32 {
affine.yield %add : f32
} else {
affine.yield %const0 : f32
}
linalg.yield %4 : f32
} -> (tensor<4xf32>)
return %1: tensor<4xf32>
}

// CHECK-LABEL: affine.if
// CHECK-NEXT: %[[ADD:.*]] = arith.addf
// CHECK-NEXT: affine.yield %[[ADD]] : f32
// CHECK-NEXT: } else {
// CHECK-NEXT: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32
// CHECK-NEXT: affine.yield %[[ZERO]] : f32
// CHECK-NEXT: }

// -----

// Same shape as the previous test case, but the values yielded by the two
// branches (%const0 and %const1) are defined at function level, outside the
// linalg.generic body. The FileCheck expectations that follow this function
// require each constant to be printed inside the branch that yields it.
// NOTE(review): "independenct" in the function name looks like a typo for
// "independent"; the symbol name is left unchanged here.
// NOTE(review): %arg0 is not used anywhere in the function body.

// Integer set tested by the affine.if; it holds when d0 <= 3.
#set = affine_set<(d0) : (-d0 + 3 >= 0)>
// Identity map used for both the input and the output of the linalg.generic.
#map = affine_map<(d0) -> (d0)>

func.func @test_affine_if_sink_with_loop_independenct_code(%arg0: f32, %arg1: tensor<4xf32>) -> tensor<4xf32> {
// Defined outside the linalg.generic; each constant is used by one branch only.
%const0 = arith.constant 0.0 : f32
%const1 = arith.constant 1.0 : f32
%0 = tensor.empty() : tensor<4xf32>
%1 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]}
ins(%arg1: tensor<4xf32>) outs(%0: tensor<4xf32>) {
^bb0(%in: f32, %out: f32):
%index = linalg.index 0 : index
%4 = affine.if #set(%index) -> f32 {
affine.yield %const1 : f32
} else {
affine.yield %const0 : f32
}
linalg.yield %4 : f32
} -> (tensor<4xf32>)
return %1: tensor<4xf32>
}

// CHECK-LABEL: affine.if
// CHECK-NEXT: %[[C1:.*]] = arith.constant 1.0
// CHECK-NEXT: affine.yield %[[C1]] : f32
// CHECK-NEXT: } else {
// CHECK-NEXT: %[[C0:.*]] = arith.constant 0.0
// CHECK-NEXT: affine.yield %[[C0]] : f32
// CHECK-NEXT: }


// -----

// affine.if without an `else` branch and without results. The only use of
// %const1 (defined at function level) is the call inside the `then` branch.
// The FileCheck expectations that follow this function require the constant
// to be printed directly before that call, inside the affine.if.
// NOTE(review): %arg0 is not used anywhere in the function body.

// Body-less declaration of the callee used inside the affine.if.
func.func private @external(f32) -> ()

// Identity map used for both the input and the output of the linalg.generic.
#map = affine_map<(d0) -> (d0)>

func.func @affine_if_no_else(%arg0: f32, %arg1: tensor<4xf32>) -> tensor<4xf32> {
%const1 = arith.constant 1.0 : f32
%0 = tensor.empty() : tensor<4xf32>
%1 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]}
ins(%arg1: tensor<4xf32>) outs(%0: tensor<4xf32>) {
^bb0(%in: f32, %out: f32):
%index = linalg.index 0 : index
// The integer set is written inline here; it holds when d0 <= 3.
affine.if affine_set<(d0) : (-d0 + 3 >= 0)>(%index) {
func.call @external(%const1) : (f32) -> ()
}
linalg.yield %in : f32
} -> (tensor<4xf32>)
return %1: tensor<4xf32>
}

// CHECK-LABEL: affine.if
// CHECK-NEXT: %[[C1:.*]] = arith.constant 1.0
// CHECK-NEXT: func.call @external(%[[C1]]) : (f32) -> ()
// CHECK-NEXT: }

0 comments on commit 09ddec3

Please sign in to comment.