Skip to content

Commit

Permalink
[xls][mlir] Add a new optimization -optimize-spawns
Browse files Browse the repository at this point in the history
This takes code like:

  sproc (%arg0) {
   spawns {
    %in, %out = chan<>
    spawn @fn(%in)
    yield %arg0, %out
   }
   next(%arg0, %out) {
    %0 = sblocking_receive %arg0
    ssend %0, %out
   }
  }

and optimizes the receive/send pair away, connecting the relevant channels
directly in the spawns region instead. The code above becomes:

  sproc (%arg0) {
   spawns {
    spawn @fn (%arg0)
    yield
   }
   next {}
  }

PiperOrigin-RevId: 700645505
  • Loading branch information
James Molloy authored and copybara-github committed Nov 27, 2024
1 parent 72da07d commit 6bf8972
Show file tree
Hide file tree
Showing 7 changed files with 387 additions and 19 deletions.
17 changes: 17 additions & 0 deletions xls/contrib/mlir/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -377,7 +377,9 @@ cc_library(
":index_type_conversion", # buildcleaner: keep
":instantiate_eprocs", # buildcleaner: keep
":math_to_xls", # buildcleaner: keep
":mlir_xls",
":normalize_xls_calls", # buildcleaner: keep
":optimize_spawns", # buildcleaner: keep
":proc_elaboration", # buildcleaner: keep
":procify_loops", # buildcleaner: keep
":scalarize", # buildcleaner: keep
Expand Down Expand Up @@ -475,6 +477,20 @@ cc_library(
],
)

# Library for the -optimize-spawns pass; its single source file is
# transforms/optimize_spawns.cc.
cc_library(
name = "optimize_spawns",
srcs = ["transforms/optimize_spawns.cc"],
deps = [
":mlir_xls",
":xls_transforms_passes_inc_gen",
"@llvm-project//mlir:FuncDialect",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:TransformUtils",
],
)

cc_binary(
name = "xls_opt",
srcs = ["tools/xls_opt/xls_opt.cc"],
Expand All @@ -486,6 +502,7 @@ cc_binary(
":instantiate_eprocs",
":mlir_xls",
":normalize_xls_calls",
":optimize_spawns",
":proc_elaboration",
":proc_utils",
":procify_loops",
Expand Down
16 changes: 15 additions & 1 deletion xls/contrib/mlir/IR/xls_ops.td
Original file line number Diff line number Diff line change
Expand Up @@ -1158,7 +1158,7 @@ def Xls_CountedForOp : Xls_Op<"counted_for",
// Sequencing operations
//===----------------------------------------------------------------------===//

def Xls_AfterAllOp : Xls_Op<"after_all", []> {
def Xls_AfterAllOp : Xls_Op<"after_all", [Pure]> {
let summary = "Constructs partial orderings among channel operations";
let description = [{
Used to construct partial orderings among channel operations.
Expand Down Expand Up @@ -1569,6 +1569,20 @@ def Xls_SprocOp : Xls_Op<"sproc", [
::llvm::ArrayRef<::mlir::Type> getResultTypes() {
return {};
}

// Appends `values` as additional yielded channels of this sproc: each value
// becomes an extra operand of the spawns region's terminator and gets a
// matching argument (same type and location) in the next region. The new
// entries are placed after the existing channels and ahead of the state
// arguments. Returns the position of the first added value, which is the same
// in the yield operands and in getNextChannels().
int64_t addYieldedChannels(ValueRange values) {
  // The next region's arguments are the channels followed by the state, so the
  // number of non-state arguments is where new channels are inserted.
  const int64_t firstNew =
      getNext().getNumArguments() - getStateArguments().size();
  Operation* spawnsYield = getSpawns().getBlocks().front().getTerminator();
  spawnsYield->insertOperands(firstNew, values);
  int64_t offset = 0;
  for (Value channel : values) {
    getNext().insertArgument(firstNew + offset, channel.getType(),
                             channel.getLoc());
    ++offset;
  }
  return firstNew;
}
}];
}

Expand Down
30 changes: 12 additions & 18 deletions xls/contrib/mlir/testdata/integration/procify.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -62,10 +62,8 @@ xls.sproc @reduce() top attributes {boundary_channel_names = []} {
// CHECK-MLIR: %tkn_out, %result = xls.blocking_receive %0, @body_arg_0 : i32
// CHECK-MLIR: %tkn_out_0, %result_1 = xls.blocking_receive %0, @body_arg_1 : i32
// CHECK-MLIR: %1 = xls.add %result, %result_1 : i32
// CHECK-MLIR: %2 = xls.after_all : !xls.token
// CHECK-MLIR: %3 = xls.trace %2, "sum_next: {}, {}"(%result, %1) : i32, i32
// CHECK-MLIR: %4 = xls.after_all : !xls.token
// CHECK-MLIR: %5 = xls.send %4, %1, @body_result_0 : i32
// CHECK-MLIR: %2 = xls.trace %0, "sum_next: {}, {}"(%result, %1) : i32, i32
// CHECK-MLIR: %3 = xls.send %0, %1, @body_result_0 : i32
// CHECK-MLIR: xls.yield %arg0 : i32
// CHECK-MLIR: }
// CHECK-MLIR: xls.eproc @reduce_for_controller_1_1(%arg0: i32) zeroinitializer {
Expand All @@ -76,20 +74,16 @@ xls.sproc @reduce() top attributes {boundary_channel_names = []} {
// CHECK-MLIR: %4 = xls.not %3 : i1
// CHECK-MLIR: %5 = xls.after_all : !xls.token
// CHECK-MLIR: %tkn_out, %result = xls.blocking_receive %5, %3, @for_arg_0 : i32
// CHECK-MLIR: %6 = xls.after_all : !xls.token
// CHECK-MLIR: %tkn_out_0, %result_1 = xls.blocking_receive %6, %4, @body_result_0 : i32
// CHECK-MLIR: %7 = xls.sel %3 in [%result_1] else %result : (i1, [i32], i32) -> i32
// CHECK-MLIR: %8 = xls.eq %arg0, %2 : (i32, i32) -> i1
// CHECK-MLIR: %9 = xls.not %8 : i1
// CHECK-MLIR: %10 = xls.after_all : !xls.token
// CHECK-MLIR: %11 = xls.send %10, %7, %8, @for_result_0 : i32
// CHECK-MLIR: %12 = xls.after_all : !xls.token
// CHECK-MLIR: %13 = xls.send %12, %7, %9, @body_arg_1 : i32
// CHECK-MLIR: %14 = xls.after_all : !xls.token
// CHECK-MLIR: %15 = xls.send %14, %arg0, %9, @body_arg_0 : i32
// CHECK-MLIR: %16 = xls.add %arg0, %1 : i32
// CHECK-MLIR: %17 = xls.sel %8 in [%16] else %0 : (i1, [i32], i32) -> i32
// CHECK-MLIR: xls.yield %17 : i32
// CHECK-MLIR: %tkn_out_0, %result_1 = xls.blocking_receive %5, %4, @body_result_0 : i32
// CHECK-MLIR: %6 = xls.sel %3 in [%result_1] else %result : (i1, [i32], i32) -> i32
// CHECK-MLIR: %7 = xls.eq %arg0, %2 : (i32, i32) -> i1
// CHECK-MLIR: %8 = xls.not %7 : i1
// CHECK-MLIR: %9 = xls.send %5, %6, %7, @for_result_0 : i32
// CHECK-MLIR: %10 = xls.send %5, %6, %8, @body_arg_1 : i32
// CHECK-MLIR: %11 = xls.send %5, %arg0, %8, @body_arg_0 : i32
// CHECK-MLIR: %12 = xls.add %arg0, %1 : i32
// CHECK-MLIR: %13 = xls.sel %7 in [%12] else %0 : (i1, [i32], i32) -> i32
// CHECK-MLIR: xls.yield %13 : i32
// CHECK-MLIR: }
// CHECK-MLIR: xls.eproc @reduce_2_0(%arg0: i32) zeroinitializer attributes {min_pipeline_stages = 2 : i64} {
// CHECK-MLIR: %0 = xls.after_all : !xls.token
Expand Down
177 changes: 177 additions & 0 deletions xls/contrib/mlir/testdata/optimize_spawns.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,177 @@
// RUN: xls/contrib/mlir/xls_opt -optimize-spawns %s 2>&1 | FileCheck %s

// Helper spawn target used by the tests below: takes one `in` channel of
// tensor<8xi32>, yields no channels from its spawns region, and passes its
// single i32 state value through unchanged.
xls.sproc @fn(%arg0: !xls.schan<tensor<8xi32>, in>) {
spawns {
xls.yield
}
next(%arg1: i32) zeroinitializer {
xls.yield %arg1 : i32
}
}

// Helper spawn target used by the tests below: takes one `out` channel of
// tensor<8xi32>, yields no channels from its spawns region, and passes its
// single i32 state value through unchanged.
xls.sproc @fn2(%arg0: !xls.schan<tensor<8xi32>, out>) {
spawns {
xls.yield
}
next(%arg1: i32) zeroinitializer {
xls.yield %arg1 : i32
}
}

// Receives from an argument channel and forwards the data to a spawned sproc.
// CHECK: xls.sproc @consume_arg(%arg0: !xls.schan<tensor<8xi32>, in>) top {
// CHECK: spawns {
// CHECK: %out, %in = xls.schan<tensor<8xi32>>("x")
// CHECK: xls.spawn @fn(%arg0) : !xls.schan<tensor<8xi32>, in>
// CHECK: xls.yield
// CHECK: }
// CHECK: next (%arg0: i32) zeroinitializer {
// CHECK: xls.yield %arg0 : i32
// CHECK: }
// CHECK: }
xls.sproc @consume_arg(%arg0: !xls.schan<tensor<8xi32>, in>) top {
spawns {
// Interior channel "x": @fn is spawned on its `in` end, and the `out` end is
// yielded to the next region together with the sproc's own input channel.
%out, %in = xls.schan<tensor<8xi32>>("x")
xls.spawn @fn(%in) : !xls.schan<tensor<8xi32>, in>
xls.yield %arg0, %out : !xls.schan<tensor<8xi32>, in>, !xls.schan<tensor<8xi32>, out>
}
next (%arg0: !xls.schan<tensor<8xi32>, in>, %arg1: !xls.schan<tensor<8xi32>, out>, %arg4: i32) zeroinitializer {
// Unpredicated receive whose result has exactly one use: the unpredicated
// send below. Here the receive and the send each use their own after_all token.
%0 = xls.after_all : !xls.token
%tkn_out, %result = xls.sblocking_receive %0, %arg0 : (!xls.token, !xls.schan<tensor<8xi32>, in>) -> (!xls.token, tensor<8xi32>)
%1 = xls.after_all : !xls.token
%2 = xls.ssend %1, %result, %arg1 : (!xls.token, tensor<8xi32>, !xls.schan<tensor<8xi32>, out>) -> !xls.token
xls.yield %arg4 : i32
}
}

// Forwards data produced by a spawned sproc to an argument (output) channel.
// CHECK: xls.sproc @produce_result(%arg0: !xls.schan<tensor<8xi32>, out>) top {
// CHECK: spawns {
// CHECK: %out, %in = xls.schan<tensor<8xi32>>("x")
// CHECK: xls.spawn @fn2(%arg0) : !xls.schan<tensor<8xi32>, out>
// CHECK: xls.yield
// CHECK: }
// CHECK: next (%arg0: i32) zeroinitializer {
// CHECK: xls.yield %arg0 : i32
// CHECK: }
// CHECK: }
xls.sproc @produce_result(%arg0: !xls.schan<tensor<8xi32>, out>) top {
spawns {
// Interior channel "x": @fn2 is spawned on its `out` end, and the `in` end is
// yielded to the next region together with the sproc's own output channel.
%out, %in = xls.schan<tensor<8xi32>>("x")
xls.spawn @fn2(%out) : !xls.schan<tensor<8xi32>, out>
xls.yield %in, %arg0 : !xls.schan<tensor<8xi32>, in>, !xls.schan<tensor<8xi32>, out>
}
next (%arg0: !xls.schan<tensor<8xi32>, in>, %arg1: !xls.schan<tensor<8xi32>, out>, %arg2: i32) zeroinitializer {
// Unpredicated receive from the interior channel; its result is used only by
// the unpredicated send to the output channel. Both ops share token %0.
%0 = xls.after_all : !xls.token
%tkn_out, %result = xls.sblocking_receive %0, %arg0 : (!xls.token, !xls.schan<tensor<8xi32>, in>) -> (!xls.token, tensor<8xi32>)
%2 = xls.ssend %0, %result, %arg1 : (!xls.token, tensor<8xi32>, !xls.schan<tensor<8xi32>, out>) -> !xls.token
xls.yield %arg2 : i32
}
}

// Contracts away the interior channel.
// CHECK: xls.sproc @contract_away_interior_channel() top {
// CHECK: spawns {
// CHECK: %out, %in = xls.schan<tensor<8xi32>>("x")
// CHECK: xls.spawn @fn2(%out) : !xls.schan<tensor<8xi32>, out>
// CHECK: xls.spawn @fn(%in) : !xls.schan<tensor<8xi32>, in>
// CHECK: xls.yield
// CHECK: }
// CHECK: next (%arg0: i32) zeroinitializer {
// CHECK: xls.yield %arg0 : i32
// CHECK: }
// CHECK: }
xls.sproc @contract_away_interior_channel() top {
spawns {
// Two interior channels, both named "x": @fn2 is spawned on the `out` end of
// the first and @fn on the `in` end of the second. The remaining two ends are
// yielded to the next region.
%out, %in = xls.schan<tensor<8xi32>>("x")
xls.spawn @fn2(%out) : !xls.schan<tensor<8xi32>, out>
%out2, %in2 = xls.schan<tensor<8xi32>>("x")
xls.spawn @fn(%in2) : !xls.schan<tensor<8xi32>, in>
xls.yield %in, %out2 : !xls.schan<tensor<8xi32>, in>, !xls.schan<tensor<8xi32>, out>
}
next (%arg0: !xls.schan<tensor<8xi32>, in>,%arg1: !xls.schan<tensor<8xi32>, out>, %arg2: i32) zeroinitializer {
// Unpredicated receive from the first channel forwarded, as its only use, by
// an unpredicated send to the second channel.
%0 = xls.after_all : !xls.token
%tkn_out, %result = xls.sblocking_receive %0, %arg0 : (!xls.token, !xls.schan<tensor<8xi32>, in>) -> (!xls.token, tensor<8xi32>)
%2 = xls.ssend %0, %result, %arg1 : (!xls.token, tensor<8xi32>, !xls.schan<tensor<8xi32>, out>) -> !xls.token
xls.yield %arg2 : i32
}
}

// The result from this receive is used twice, so we can't contract away the
// interior channel.
// CHECK: xls.sproc @receive_used_twice() top {
// CHECK: spawns {
// CHECK: %out, %in = xls.schan<tensor<8xi32>>("x")
// CHECK: xls.spawn @fn2(%out) : !xls.schan<tensor<8xi32>, out>
// CHECK: %out_0, %in_1 = xls.schan<tensor<8xi32>>("x")
// CHECK: xls.spawn @fn(%in_1) : !xls.schan<tensor<8xi32>, in>
// CHECK: xls.yield %in, %out_0 : !xls.schan<tensor<8xi32>, in>, !xls.schan<tensor<8xi32>, out>
// CHECK: }
xls.sproc @receive_used_twice() top {
spawns {
// Same channel and spawn setup as @contract_away_interior_channel.
%out, %in = xls.schan<tensor<8xi32>>("x")
xls.spawn @fn2(%out) : !xls.schan<tensor<8xi32>, out>
%out2, %in2 = xls.schan<tensor<8xi32>>("x")
xls.spawn @fn(%in2) : !xls.schan<tensor<8xi32>, in>
xls.yield %in, %out2 : !xls.schan<tensor<8xi32>, in>, !xls.schan<tensor<8xi32>, out>
}
next (%arg0: !xls.schan<tensor<8xi32>, in>,%arg1: !xls.schan<tensor<8xi32>, out>, %arg2: i32) zeroinitializer {
%0 = xls.after_all : !xls.token
%tkn_out, %result = xls.sblocking_receive %0, %arg0 : (!xls.token, !xls.schan<tensor<8xi32>, in>) -> (!xls.token, tensor<8xi32>)
// %result feeds two sends on the same channel, so the receive has more than
// one use; the expected output keeps both channels yielded.
%2 = xls.ssend %0, %result, %arg1 : (!xls.token, tensor<8xi32>, !xls.schan<tensor<8xi32>, out>) -> !xls.token
%3 = xls.ssend %0, %result, %arg1 : (!xls.token, tensor<8xi32>, !xls.schan<tensor<8xi32>, out>) -> !xls.token
xls.yield %arg2 : i32
}
}

// The send is predicated, so we can't contract away the interior channel.
// CHECK: xls.sproc @send_predicated() top {
// CHECK: spawns {
// CHECK: %out, %in = xls.schan<tensor<8xi32>>("x")
// CHECK: xls.spawn @fn2(%out) : !xls.schan<tensor<8xi32>, out>
// CHECK: %out_0, %in_1 = xls.schan<tensor<8xi32>>("x")
// CHECK: xls.spawn @fn(%in_1) : !xls.schan<tensor<8xi32>, in>
// CHECK: xls.yield %in, %out_0 : !xls.schan<tensor<8xi32>, in>, !xls.schan<tensor<8xi32>, out>
// CHECK: }
xls.sproc @send_predicated() top {
spawns {
// Same channel and spawn setup as @contract_away_interior_channel.
%out, %in = xls.schan<tensor<8xi32>>("x")
xls.spawn @fn2(%out) : !xls.schan<tensor<8xi32>, out>
%out2, %in2 = xls.schan<tensor<8xi32>>("x")
xls.spawn @fn(%in2) : !xls.schan<tensor<8xi32>, in>
xls.yield %in, %out2 : !xls.schan<tensor<8xi32>, in>, !xls.schan<tensor<8xi32>, out>
}
next (%arg0: !xls.schan<tensor<8xi32>, in>,%arg1: !xls.schan<tensor<8xi32>, out>, %arg2: i32) zeroinitializer {
%0 = xls.after_all : !xls.token
%tkn_out, %result = xls.sblocking_receive %0, %arg0 : (!xls.token, !xls.schan<tensor<8xi32>, in>) -> (!xls.token, tensor<8xi32>)
// The send carries an extra i1 predicate operand (%true); the expected output
// keeps both channels yielded even though the predicate is a constant.
%true = arith.constant 1 : i1
%2 = xls.ssend %0, %result, %arg1, %true : (!xls.token, tensor<8xi32>, !xls.schan<tensor<8xi32>, out>, i1) -> !xls.token
xls.yield %arg2 : i32
}
}

// The receive is predicated, so we can't contract away the interior channel.
// CHECK: xls.sproc @recv_predicated() top {
// CHECK: spawns {
// CHECK: %out, %in = xls.schan<tensor<8xi32>>("x")
// CHECK: xls.spawn @fn2(%out) : !xls.schan<tensor<8xi32>, out>
// CHECK: %out_0, %in_1 = xls.schan<tensor<8xi32>>("x")
// CHECK: xls.spawn @fn(%in_1) : !xls.schan<tensor<8xi32>, in>
// CHECK: xls.yield %in, %out_0 : !xls.schan<tensor<8xi32>, in>, !xls.schan<tensor<8xi32>, out>
// CHECK: }
xls.sproc @recv_predicated() top {
spawns {
// Same channel and spawn setup as @contract_away_interior_channel.
%out, %in = xls.schan<tensor<8xi32>>("x")
xls.spawn @fn2(%out) : !xls.schan<tensor<8xi32>, out>
%out2, %in2 = xls.schan<tensor<8xi32>>("x")
xls.spawn @fn(%in2) : !xls.schan<tensor<8xi32>, in>
xls.yield %in, %out2 : !xls.schan<tensor<8xi32>, in>, !xls.schan<tensor<8xi32>, out>
}
next (%arg0: !xls.schan<tensor<8xi32>, in>,%arg1: !xls.schan<tensor<8xi32>, out>, %arg2: i32) zeroinitializer {
%0 = xls.after_all : !xls.token
// The receive carries an extra i1 predicate operand (%true); the expected
// output keeps both channels yielded even though the predicate is a constant.
%true = arith.constant 1 : i1
%tkn_out, %result = xls.sblocking_receive %0, %arg0, %true : (!xls.token, !xls.schan<tensor<8xi32>, in>, i1) -> (!xls.token, tensor<8xi32>)
%2 = xls.ssend %0, %result, %arg1 : (!xls.token, tensor<8xi32>, !xls.schan<tensor<8xi32>, out>) -> !xls.token
xls.yield %arg2 : i32
}
}
Loading

0 comments on commit 6bf8972

Please sign in to comment.