diff --git a/xls/contrib/mlir/BUILD b/xls/contrib/mlir/BUILD index d162333664..049ab1f5e7 100644 --- a/xls/contrib/mlir/BUILD +++ b/xls/contrib/mlir/BUILD @@ -377,7 +377,9 @@ cc_library( ":index_type_conversion", # buildcleaner: keep ":instantiate_eprocs", # buildcleaner: keep ":math_to_xls", # buildcleaner: keep + ":mlir_xls", ":normalize_xls_calls", # buildcleaner: keep + ":optimize_spawns", # buildcleaner: keep ":proc_elaboration", # buildcleaner: keep ":procify_loops", # buildcleaner: keep ":scalarize", # buildcleaner: keep @@ -475,6 +477,20 @@ cc_library( ], ) +cc_library( + name = "optimize_spawns", + srcs = ["transforms/optimize_spawns.cc"], + deps = [ + ":mlir_xls", + ":xls_transforms_passes_inc_gen", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TransformUtils", + ], +) + cc_binary( name = "xls_opt", srcs = ["tools/xls_opt/xls_opt.cc"], @@ -486,6 +502,7 @@ cc_binary( ":instantiate_eprocs", ":mlir_xls", ":normalize_xls_calls", + ":optimize_spawns", ":proc_elaboration", ":proc_utils", ":procify_loops", diff --git a/xls/contrib/mlir/IR/xls_ops.td b/xls/contrib/mlir/IR/xls_ops.td index 131863d9d6..7086201aa9 100644 --- a/xls/contrib/mlir/IR/xls_ops.td +++ b/xls/contrib/mlir/IR/xls_ops.td @@ -1158,7 +1158,7 @@ def Xls_CountedForOp : Xls_Op<"counted_for", // Sequencing operations //===----------------------------------------------------------------------===// -def Xls_AfterAllOp : Xls_Op<"after_all", []> { +def Xls_AfterAllOp : Xls_Op<"after_all", [Pure]> { let summary = "Constructs partial orderings among channel operations"; let description = [{ Used to construct partial orderings among channel operations. @@ -1569,6 +1569,20 @@ def Xls_SprocOp : Xls_Op<"sproc", [ ::llvm::ArrayRef<::mlir::Type> getResultTypes() { return {}; } + + // Adds the given values to the yield op and the next region. Returns the + // index of the first value in both the yield operands and the + // getNextChannels(). + int64_t addYieldedChannels(ValueRange values) { + Operation* yield = getSpawns().getBlocks().front().getTerminator(); + int64_t index = getNext().getNumArguments() - getStateArguments().size(); + yield->insertOperands(index, values); + int64_t i = index; + for (Value value : values) { + getNext().insertArgument(i++, value.getType(), value.getLoc()); + } + return index; + } }]; } diff --git a/xls/contrib/mlir/testdata/integration/procify.mlir b/xls/contrib/mlir/testdata/integration/procify.mlir index 4956bda7ba..a10311cfcd 100644 --- a/xls/contrib/mlir/testdata/integration/procify.mlir +++ b/xls/contrib/mlir/testdata/integration/procify.mlir @@ -62,10 +62,8 @@ xls.sproc @reduce() top attributes {boundary_channel_names = []} { // CHECK-MLIR: %tkn_out, %result = xls.blocking_receive %0, @body_arg_0 : i32 // CHECK-MLIR: %tkn_out_0, %result_1 = xls.blocking_receive %0, @body_arg_1 : i32 // CHECK-MLIR: %1 = xls.add %result, %result_1 : i32 -// CHECK-MLIR: %2 = xls.after_all : !xls.token -// CHECK-MLIR: %3 = xls.trace %2, "sum_next: {}, {}"(%result, %1) : i32, i32 -// CHECK-MLIR: %4 = xls.after_all : !xls.token -// CHECK-MLIR: %5 = xls.send %4, %1, @body_result_0 : i32 +// CHECK-MLIR: %2 = xls.trace %0, "sum_next: {}, {}"(%result, %1) : i32, i32 +// CHECK-MLIR: %3 = xls.send %0, %1, @body_result_0 : i32 // CHECK-MLIR: xls.yield %arg0 : i32 // CHECK-MLIR: } // CHECK-MLIR: xls.eproc @reduce_for_controller_1_1(%arg0: i32) zeroinitializer { @@ -76,20 +74,16 @@ xls.sproc @reduce() top attributes {boundary_channel_names = []} { // CHECK-MLIR: %4 = xls.not %3 : i1 // CHECK-MLIR: %5 = xls.after_all : !xls.token // CHECK-MLIR: %tkn_out, %result = xls.blocking_receive %5, %3, @for_arg_0 : i32 -// CHECK-MLIR: %6 = xls.after_all : !xls.token -// CHECK-MLIR: %tkn_out_0, %result_1 = xls.blocking_receive %6, %4, @body_result_0 : i32 -// CHECK-MLIR: %7 = xls.sel %3 in [%result_1] else %result : (i1, [i32], i32) -> i32 -// CHECK-MLIR: %8 = xls.eq %arg0, %2 : (i32, i32) -> i1 -// CHECK-MLIR: %9 = xls.not %8 : i1 -// CHECK-MLIR: %10 = xls.after_all : !xls.token -// CHECK-MLIR: %11 = xls.send %10, %7, %8, @for_result_0 : i32 -// CHECK-MLIR: %12 = xls.after_all : !xls.token -// CHECK-MLIR: %13 = xls.send %12, %7, %9, @body_arg_1 : i32 -// CHECK-MLIR: %14 = xls.after_all : !xls.token -// CHECK-MLIR: %15 = xls.send %14, %arg0, %9, @body_arg_0 : i32 -// CHECK-MLIR: %16 = xls.add %arg0, %1 : i32 -// CHECK-MLIR: %17 = xls.sel %8 in [%16] else %0 : (i1, [i32], i32) -> i32 -// CHECK-MLIR: xls.yield %17 : i32 +// CHECK-MLIR: %tkn_out_0, %result_1 = xls.blocking_receive %5, %4, @body_result_0 : i32 +// CHECK-MLIR: %6 = xls.sel %3 in [%result_1] else %result : (i1, [i32], i32) -> i32 +// CHECK-MLIR: %7 = xls.eq %arg0, %2 : (i32, i32) -> i1 +// CHECK-MLIR: %8 = xls.not %7 : i1 +// CHECK-MLIR: %9 = xls.send %5, %6, %7, @for_result_0 : i32 +// CHECK-MLIR: %10 = xls.send %5, %6, %8, @body_arg_1 : i32 +// CHECK-MLIR: %11 = xls.send %5, %arg0, %8, @body_arg_0 : i32 +// CHECK-MLIR: %12 = xls.add %arg0, %1 : i32 +// CHECK-MLIR: %13 = xls.sel %7 in [%12] else %0 : (i1, [i32], i32) -> i32 +// CHECK-MLIR: xls.yield %13 : i32 // CHECK-MLIR: } // CHECK-MLIR: xls.eproc @reduce_2_0(%arg0: i32) zeroinitializer attributes {min_pipeline_stages = 2 : i64} { // CHECK-MLIR: %0 = xls.after_all : !xls.token diff --git a/xls/contrib/mlir/testdata/optimize_spawns.mlir b/xls/contrib/mlir/testdata/optimize_spawns.mlir new file mode 100644 index 0000000000..63792f35ab --- /dev/null +++ b/xls/contrib/mlir/testdata/optimize_spawns.mlir @@ -0,0 +1,177 @@ +// RUN: xls/contrib/mlir/xls_opt -optimize-spawns %s 2>&1 | FileCheck %s + +xls.sproc @fn(%arg0: !xls.schan, in>) { + spawns { + xls.yield + } + next(%arg1: i32) zeroinitializer { + xls.yield %arg1 : i32 + } +} + +xls.sproc @fn2(%arg0: !xls.schan, out>) { + spawns { + xls.yield + } + next(%arg1: i32) zeroinitializer { + xls.yield %arg1 : i32 + } +} + +// Consumes an argument and passes to a spawn. +// CHECK: xls.sproc @consume_arg(%arg0: !xls.schan, in>) top { +// CHECK: spawns { +// CHECK: %out, %in = xls.schan>("x") +// CHECK: xls.spawn @fn(%arg0) : !xls.schan, in> +// CHECK: xls.yield +// CHECK: } +// CHECK: next (%arg0: i32) zeroinitializer { +// CHECK: xls.yield %arg0 : i32 +// CHECK: } +// CHECK: } +xls.sproc @consume_arg(%arg0: !xls.schan, in>) top { + spawns { + %out, %in = xls.schan>("x") + xls.spawn @fn(%in) : !xls.schan, in> + xls.yield %arg0, %out : !xls.schan, in>, !xls.schan, out> + } + next (%arg0: !xls.schan, in>, %arg1: !xls.schan, out>, %arg4: i32) zeroinitializer { + %0 = xls.after_all : !xls.token + %tkn_out, %result = xls.sblocking_receive %0, %arg0 : (!xls.token, !xls.schan, in>) -> (!xls.token, tensor<8xi32>) + %1 = xls.after_all : !xls.token + %2 = xls.ssend %1, %result, %arg1 : (!xls.token, tensor<8xi32>, !xls.schan, out>) -> !xls.token + xls.yield %arg4 : i32 + } +} + +// Produces a result from a spawn. +// CHECK: xls.sproc @produce_result(%arg0: !xls.schan, out>) top { +// CHECK: spawns { +// CHECK: %out, %in = xls.schan>("x") +// CHECK: xls.spawn @fn2(%arg0) : !xls.schan, out> +// CHECK: xls.yield +// CHECK: } +// CHECK: next (%arg0: i32) zeroinitializer { +// CHECK: xls.yield %arg0 : i32 +// CHECK: } +// CHECK: } +xls.sproc @produce_result(%arg0: !xls.schan, out>) top { + spawns { + %out, %in = xls.schan>("x") + xls.spawn @fn2(%out) : !xls.schan, out> + xls.yield %in, %arg0 : !xls.schan, in>, !xls.schan, out> + } + next (%arg0: !xls.schan, in>, %arg1: !xls.schan, out>, %arg2: i32) zeroinitializer { + %0 = xls.after_all : !xls.token + %tkn_out, %result = xls.sblocking_receive %0, %arg0 : (!xls.token, !xls.schan, in>) -> (!xls.token, tensor<8xi32>) + %2 = xls.ssend %0, %result, %arg1 : (!xls.token, tensor<8xi32>, !xls.schan, out>) -> !xls.token + xls.yield %arg2 : i32 + } +} + +// Contracts away the interior channel. +// CHECK: xls.sproc @contract_away_interior_channel() top { +// CHECK: spawns { +// CHECK: %out, %in = xls.schan>("x") +// CHECK: xls.spawn @fn2(%out) : !xls.schan, out> +// CHECK: xls.spawn @fn(%in) : !xls.schan, in> +// CHECK: xls.yield +// CHECK: } +// CHECK: next (%arg0: i32) zeroinitializer { +// CHECK: xls.yield %arg0 : i32 +// CHECK: } +// CHECK: } +xls.sproc @contract_away_interior_channel() top { + spawns { + %out, %in = xls.schan>("x") + xls.spawn @fn2(%out) : !xls.schan, out> + %out2, %in2 = xls.schan>("x") + xls.spawn @fn(%in2) : !xls.schan, in> + xls.yield %in, %out2 : !xls.schan, in>, !xls.schan, out> + } + next (%arg0: !xls.schan, in>,%arg1: !xls.schan, out>, %arg2: i32) zeroinitializer { + %0 = xls.after_all : !xls.token + %tkn_out, %result = xls.sblocking_receive %0, %arg0 : (!xls.token, !xls.schan, in>) -> (!xls.token, tensor<8xi32>) + %2 = xls.ssend %0, %result, %arg1 : (!xls.token, tensor<8xi32>, !xls.schan, out>) -> !xls.token + xls.yield %arg2 : i32 + } +} + +// The result from this receive is used twice, so we can't contract away the +// interior channel. +// CHECK: xls.sproc @receive_used_twice() top { +// CHECK: spawns { +// CHECK: %out, %in = xls.schan>("x") +// CHECK: xls.spawn @fn2(%out) : !xls.schan, out> +// CHECK: %out_0, %in_1 = xls.schan>("x") +// CHECK: xls.spawn @fn(%in_1) : !xls.schan, in> +// CHECK: xls.yield %in, %out_0 : !xls.schan, in>, !xls.schan, out> +// CHECK: } +xls.sproc @receive_used_twice() top { + spawns { + %out, %in = xls.schan>("x") + xls.spawn @fn2(%out) : !xls.schan, out> + %out2, %in2 = xls.schan>("x") + xls.spawn @fn(%in2) : !xls.schan, in> + xls.yield %in, %out2 : !xls.schan, in>, !xls.schan, out> + } + next (%arg0: !xls.schan, in>,%arg1: !xls.schan, out>, %arg2: i32) zeroinitializer { + %0 = xls.after_all : !xls.token + %tkn_out, %result = xls.sblocking_receive %0, %arg0 : (!xls.token, !xls.schan, in>) -> (!xls.token, tensor<8xi32>) + %2 = xls.ssend %0, %result, %arg1 : (!xls.token, tensor<8xi32>, !xls.schan, out>) -> !xls.token + %3 = xls.ssend %0, %result, %arg1 : (!xls.token, tensor<8xi32>, !xls.schan, out>) -> !xls.token + xls.yield %arg2 : i32 + } +} + +// The send is predicated, so we can't contract away the interior channel. +// CHECK: xls.sproc @send_predicated() top { +// CHECK: spawns { +// CHECK: %out, %in = xls.schan>("x") +// CHECK: xls.spawn @fn2(%out) : !xls.schan, out> +// CHECK: %out_0, %in_1 = xls.schan>("x") +// CHECK: xls.spawn @fn(%in_1) : !xls.schan, in> +// CHECK: xls.yield %in, %out_0 : !xls.schan, in>, !xls.schan, out> +// CHECK: } +xls.sproc @send_predicated() top { + spawns { + %out, %in = xls.schan>("x") + xls.spawn @fn2(%out) : !xls.schan, out> + %out2, %in2 = xls.schan>("x") + xls.spawn @fn(%in2) : !xls.schan, in> + xls.yield %in, %out2 : !xls.schan, in>, !xls.schan, out> + } + next (%arg0: !xls.schan, in>,%arg1: !xls.schan, out>, %arg2: i32) zeroinitializer { + %0 = xls.after_all : !xls.token + %tkn_out, %result = xls.sblocking_receive %0, %arg0 : (!xls.token, !xls.schan, in>) -> (!xls.token, tensor<8xi32>) + %true = arith.constant 1 : i1 + %2 = xls.ssend %0, %result, %arg1, %true : (!xls.token, tensor<8xi32>, !xls.schan, out>, i1) -> !xls.token + xls.yield %arg2 : i32 + } +} + +// The receive is predicated, so we can't contract away the interior channel. +// CHECK: xls.sproc @recv_predicated() top { +// CHECK: spawns { +// CHECK: %out, %in = xls.schan>("x") +// CHECK: xls.spawn @fn2(%out) : !xls.schan, out> +// CHECK: %out_0, %in_1 = xls.schan>("x") +// CHECK: xls.spawn @fn(%in_1) : !xls.schan, in> +// CHECK: xls.yield %in, %out_0 : !xls.schan, in>, !xls.schan, out> +// CHECK: } +xls.sproc @recv_predicated() top { + spawns { + %out, %in = xls.schan>("x") + xls.spawn @fn2(%out) : !xls.schan, out> + %out2, %in2 = xls.schan>("x") + xls.spawn @fn(%in2) : !xls.schan, in> + xls.yield %in, %out2 : !xls.schan, in>, !xls.schan, out> + } + next (%arg0: !xls.schan, in>,%arg1: !xls.schan, out>, %arg2: i32) zeroinitializer { + %0 = xls.after_all : !xls.token + %true = arith.constant 1 : i1 + %tkn_out, %result = xls.sblocking_receive %0, %arg0, %true : (!xls.token, !xls.schan, in>, i1) -> (!xls.token, tensor<8xi32>) + %2 = xls.ssend %0, %result, %arg1 : (!xls.token, tensor<8xi32>, !xls.schan, out>) -> !xls.token + xls.yield %arg2 : i32 + } +} diff --git a/xls/contrib/mlir/transforms/optimize_spawns.cc b/xls/contrib/mlir/transforms/optimize_spawns.cc new file mode 100644 index 0000000000..3555833e31 --- /dev/null +++ b/xls/contrib/mlir/transforms/optimize_spawns.cc @@ -0,0 +1,122 @@ +// Copyright 2024 The XLS Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include + +#include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/include/mlir/IR/PatternMatch.h" +#include "mlir/include/mlir/IR/Value.h" +#include "mlir/include/mlir/IR/Visitors.h" +#include "mlir/include/mlir/Pass/Pass.h" // IWYU pragma: keep +#include "mlir/include/mlir/Support/LLVM.h" +#include "mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "xls/contrib/mlir/IR/xls_ops.h" + +namespace mlir::xls { + +#define GEN_PASS_DEF_OPTIMIZESPAWNSPASS +#include "xls/contrib/mlir/transforms/passes.h.inc" + +namespace { + +class SendOfBlockingReceiveOp : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(SSendOp send, + PatternRewriter& rewriter) const override { + auto recv = send.getData().getDefiningOp(); + // We assume that nobody uses the token output of the send or the token + // output of the receive. + if (!recv || recv.getPredicate() || send.getPredicate() || + !recv.getResult().hasOneUse()) { + return rewriter.notifyMatchFailure( + send, "not a recv, or predicated, or recv has multiple uses"); + } + // Only proceed if the token is trivial for the send, or is the output token + // of the receive. + if (!isTrivialToken(send.getTkn()) && send.getTkn() != recv.getTknOut()) { + return rewriter.notifyMatchFailure(send, "send token is not trivial"); + } + + // It is safe to remove the receive and send. + SprocOp sproc = send->getParentOfType(); + int64_t sendChanIdx = cast(send.getChannel()).getArgNumber(); + int64_t recvChanIdx = cast(recv.getChannel()).getArgNumber(); + Value sendChan = sproc.getYieldedChannels()[sendChanIdx]; + Value recvChan = sproc.getYieldedChannels()[recvChanIdx]; + + if (auto sendChanOp = sendChan.getDefiningOp()) { + // The send channel is an interior channel, so replace uses of its "in" + // port with the recv channel (recv can either be interior or argument). + Value sendChanReceiver = sendChanOp.getIn(); + rewriter.replaceUsesWithIf(sendChanReceiver, recvChan, + [&](OpOperand& opOperand) { + return !isa(opOperand.getOwner()); + }); + } else if (auto recvChanOp = recvChan.getDefiningOp()) { + // The recv channel is an interior channel, so replace uses of its "out" + // port with the send channel (send can either be interior or a result). + Value recvChanSender = recvChanOp.getOut(); + rewriter.replaceUsesWithIf(recvChanSender, sendChan, + [&](OpOperand& opOperand) { + return !isa(opOperand.getOwner()); + }); + } else { + // The recv channel is an argument and the send channel is a result. This + // needs a send/recv pair so we can't optimize this. + return failure(); + } + + rewriter.replaceAllUsesWith(recv.getTknOut(), recv.getTkn()); + rewriter.replaceAllUsesWith(send.getResult(), send.getTkn()); + rewriter.eraseOp(send); + rewriter.eraseOp(recv); + + // When erasing arguments, we need to erase them from highest to lowest so + // that the indices are not invalidated. + SmallVector channelIndices = {sendChanIdx, recvChanIdx}; + std::sort(channelIndices.begin(), channelIndices.end()); + std::reverse(channelIndices.begin(), channelIndices.end()); + + for (int64_t chanIdx : channelIndices) { + sproc.getNext().eraseArgument(chanIdx); + sproc.getSpawns().front().getTerminator()->eraseOperand(chanIdx); + } + return success(); + } + + bool isTrivialToken(Value token) const { + AfterAllOp afterAll = token.getDefiningOp(); + return afterAll && afterAll.getOperands().empty(); + } +}; + +class OptimizeSpawnsPass + : public impl::OptimizeSpawnsPassBase { + public: + void runOnOperation() override { + RewritePatternSet patterns(&getContext()); + patterns.add(&getContext()); + if (failed(applyPatternsAndFoldGreedily(getOperation().getNext(), + std::move(patterns)))) { + signalPassFailure(); + } + } +}; + +} // namespace +} // namespace mlir::xls diff --git a/xls/contrib/mlir/transforms/passes.td b/xls/contrib/mlir/transforms/passes.td index 314162d988..f859f1ca71 100644 --- a/xls/contrib/mlir/transforms/passes.td +++ b/xls/contrib/mlir/transforms/passes.td @@ -147,4 +147,46 @@ def ProcifyLoopsPass : Pass<"procify-loops", "mlir::ModuleOp"> { ]; } +def OptimizeSpawnsPass : Pass<"optimize-spawns", "SprocOp"> { + let summary = "Optimizes spawn ops"; + let description = [{ + Optimizes away sends and receives to channels that are interior to multiple + spawns, or argument or result channels. For example: + + ``` + spawns { + chan X, Y; + spawn @y(X, Y); + chan A, B; + spawn @x(A, B); + yield X, Y, A, B; + } + next (A, B) { + ssend X; + %0 = sblocking_recv Y; + ssend %0 to A; // Pointless recv/send pair, could be removed. + sblocking_recv B; + ... use B ... + } + ``` + + into: + + ``` + spawns { + chan X, Y; + spawn @y(X, Y); + chan B; + spawn @x(Y, B); // Channel A is removed and replaced by Y. + yield X, B; + } + next (A, B) { + ssend X; + sblocking_recv B; + ... use B ... + } + ``` + }]; +} + #endif // MLIR_XLS_TRANSFORMS_PASSES diff --git a/xls/contrib/mlir/transforms/xls_lower.cc b/xls/contrib/mlir/transforms/xls_lower.cc index eec7e0bd35..78a56c97e9 100644 --- a/xls/contrib/mlir/transforms/xls_lower.cc +++ b/xls/contrib/mlir/transforms/xls_lower.cc @@ -17,6 +17,7 @@ #include "mlir/include/mlir/Pass/PassManager.h" #include "mlir/include/mlir/Pass/PassRegistry.h" #include "mlir/include/mlir/Transforms/Passes.h" +#include "xls/contrib/mlir/IR/xls_ops.h" #include "xls/contrib/mlir/transforms/passes.h" namespace mlir::xls { @@ -26,6 +27,7 @@ void XlsLowerPassPipeline(OpPassManager& pm, pm.addPass(createProcifyLoopsPass({ .apply_by_default = options.procify_loops_apply_by_default, })); + pm.addNestedPass(createOptimizeSpawnsPass()); pm.addPass(mlir::createCanonicalizerPass()); pm.addPass(createProcElaborationPass()); if (options.instantiate_eprocs) {