Skip to content

Commit

Permalink
Adding --iree-scheduling-initialization-mode= flag. (#19778)
Browse files Browse the repository at this point in the history
This allows for choosing whether initializers return immediately with
asynchronous work still pending or if they block and wait prior to
returning. Users benchmarking will want to use synchronous mode while
users wanting to overlap other work with initialization will want
asynchronous mode. Since all existing frameworks operate with
synchronous initialization the default is changed to that.

For some spooky reason (#19795) this causes a few more onnx op tests to
fail in addition to existing ones that were already failing. They've
been xfailed for now because I cannot figure out what's going on or
reproduce the issue.

Fixes #19770.
  • Loading branch information
benvanik authored Jan 24, 2025
1 parent c52eb68 commit bbe7f5c
Show file tree
Hide file tree
Showing 20 changed files with 341 additions and 47 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -484,6 +484,30 @@ class ValueResourceUsage : public AbstractResourceUsage<DFX::ValueElement> {
getState() ^= resultUsage.getState();
}
})
.Case<IREE::Stream::AsyncExecuteOp, IREE::Stream::AsyncConcurrentOp>(
[&](auto op) {
IREE::Stream::AsyncConcurrentOp c;
// Take on the state from the internal usage.
for (auto yieldOp :
op.getClosureBodyRegion()
.template getOps<IREE::Stream::YieldOp>()) {
auto &yieldUsage = solver.getElementFor<ValueResourceUsage>(
*this,
Position::forValue(
yieldOp.getOperand(result.getResultNumber())),
DFX::Resolution::REQUIRED);
getState() ^= yieldUsage.getState();
}
// If the result is passed through as a tied operand then also
// inherit the original state.
auto tiedOperand = op.getTiedResultOperand(result);
if (tiedOperand) {
auto &tiedUsage = solver.getElementFor<ValueResourceUsage>(
*this, Position::forValue(tiedOperand),
DFX::Resolution::REQUIRED);
getState() ^= tiedUsage.getState();
}
})
.Default([&](Operation *op) {});
}

Expand Down Expand Up @@ -805,6 +829,30 @@ class ValueResourceUsage : public AbstractResourceUsage<DFX::ValueElement> {
getState() ^= resultUsage.getState();
}
})
.Case<IREE::Stream::AsyncExecuteOp, IREE::Stream::AsyncConcurrentOp>(
[&](auto op) {
// Take on the traits of all ops within the execution region that
// use the value and handle ties if needed.
auto &operandUsage = solver.getElementFor<ValueResourceUsage>(
*this,
Position::forValue(
op.getClosureBodyRegion().getArgument(operandIdx)),
DFX::Resolution::REQUIRED);
getState() ^= operandUsage.getState();
for (auto result : op.getOperandTiedResults(operandIdx)) {
auto &resultUsage = solver.getElementFor<ValueResourceUsage>(
*this, Position::forValue(result),
DFX::Resolution::REQUIRED);
getState() ^= resultUsage.getState();
}
})
.Case([&](IREE::Stream::YieldOp op) {
// Take on the traits of the result of the parent operation.
Value result = op->getParentOp()->getResult(operandIdx);
auto &resultUsage = solver.getElementFor<ValueResourceUsage>(
*this, Position::forValue(result), DFX::Resolution::REQUIRED);
getState() ^= resultUsage.getState();
})
.Default([&](Operation *op) {});
}

Expand Down
41 changes: 35 additions & 6 deletions compiler/src/iree/compiler/Dialect/Stream/IR/StreamOpFolders.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -461,9 +461,11 @@ struct ChainDependentAwaits : public OpRewritePattern<Op> {
for (auto operand : llvm::enumerate(op.getResourceOperands())) {
if (auto awaitOp =
operand.value().template getDefiningOp<TimepointAwaitOp>()) {
newTimepoints.push_back(awaitOp.getAwaitTimepoint());
replacements.push_back(std::make_pair(
operand.index(), awaitOp.getTiedResultOperand(operand.value())));
if (!awaitOp.getSync()) {
newTimepoints.push_back(awaitOp.getAwaitTimepoint());
replacements.push_back(std::make_pair(
operand.index(), awaitOp.getTiedResultOperand(operand.value())));
}
}
}
if (replacements.empty())
Expand Down Expand Up @@ -3050,7 +3052,9 @@ findSourceAwaitOp(Value resource) {
baseResource.getDefiningOp())) {
if (auto awaitOp = dyn_cast<IREE::Stream::TimepointAwaitOp>(
baseResource.getDefiningOp())) {
return {awaitOp, baseResource};
if (!awaitOp.getSync()) {
return {awaitOp, baseResource};
}
}
auto tiedValue = definingOp.getTiedResultOperand(baseResource);
if (!tiedValue)
Expand Down Expand Up @@ -3141,6 +3145,11 @@ struct SinkAwaitToFirstConsumer : public OpRewritePattern<TimepointAwaitOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(TimepointAwaitOp op,
PatternRewriter &rewriter) const override {
// Don't move sync points as they may be implicitly guarding execution.
if (op.getSync()) {
return rewriter.notifyMatchFailure(op, "sync awaits cannot be moved");
}

// TODO(benvanik): amortize this dominance calculation.
DominanceInfo domInfo(op->getParentOp());

Expand Down Expand Up @@ -3197,6 +3206,7 @@ struct SinkSubviewsAcrossAwaits : public OpRewritePattern<TimepointAwaitOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(TimepointAwaitOp op,
PatternRewriter &rewriter) const override {
rewriter.setInsertionPointAfter(op);
rewriter.startOpModification(op);
bool didChange = false;
for (auto operand : llvm::enumerate(op.getResourceOperands())) {
Expand Down Expand Up @@ -3276,7 +3286,7 @@ struct GroupAwaitsByTimepoint : public OpRewritePattern<TimepointAwaitOp> {
if (dominanceInfo.dominates(use.getOwner(), op))
continue;
auto awaitOp = dyn_cast<TimepointAwaitOp>(use.getOwner());
if (!awaitOp)
if (!awaitOp || awaitOp.getSync())
continue;
// Ensure all dependencies of the await op are available.
if (!areAllOperandsDefinedBy(awaitOp, op, dominanceInfo)) {
Expand Down Expand Up @@ -3351,6 +3361,7 @@ struct FoldDuplicateAwaitResources : public OpRewritePattern<TimepointAwaitOp> {
// Create replacement op with deduped operands/results.
auto newOp = rewriter.create<IREE::Stream::TimepointAwaitOp>(
op.getLoc(), newOperands, newOperandSizes, op.getAwaitTimepoint());
newOp.setSync(op.getSync());

// Replace all duplicate results with the base results.
for (auto &replacement : replacements) {
Expand All @@ -3363,6 +3374,24 @@ struct FoldDuplicateAwaitResources : public OpRewritePattern<TimepointAwaitOp> {
}
};

struct ElideUnusedTimepointAwait : public OpRewritePattern<TimepointAwaitOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(TimepointAwaitOp op,
PatternRewriter &rewriter) const override {
// If there are any uses the await is required to associate the timepoint.
if (!op.use_empty()) {
return failure();
}
// If the await is a sync point then we cannot elide it even if it has no
// uses.
if (op.getSync()) {
return rewriter.notifyMatchFailure(op, "sync ops cannot be elided");
}
rewriter.eraseOp(op);
return success();
}
};

} // namespace

void TimepointAwaitOp::getCanonicalizationPatterns(RewritePatternSet &results,
Expand All @@ -3373,7 +3402,7 @@ void TimepointAwaitOp::getCanonicalizationPatterns(RewritePatternSet &results,
results.insert<SinkSubviewsAcrossAwaits>(context);
results.insert<GroupAwaitsByTimepoint>(context);
results.insert<FoldDuplicateAwaitResources>(context);
results.insert<ElideUnusedOp<TimepointAwaitOp>>(context);
results.insert<ElideUnusedTimepointAwait>(context);
}

//===----------------------------------------------------------------------===//
Expand Down
4 changes: 3 additions & 1 deletion compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -3793,7 +3793,8 @@ def Stream_TimepointAwaitOp : Stream_PureOp<"timepoint.await", [
Stream_StagingResource,
]>>:$resource_operands,
Variadic<Stream_Size>:$resource_operand_sizes,
Stream_Timepoint:$await_timepoint
Stream_Timepoint:$await_timepoint,
UnitAttr:$sync
);
let results = (outs
Variadic<AnyTypeOf<[
Expand All @@ -3803,6 +3804,7 @@ def Stream_TimepointAwaitOp : Stream_PureOp<"timepoint.await", [
);

let assemblyFormat = [{
(`sync` $sync^)?
$await_timepoint `=` `` `>`
$resource_operands `:`
custom<ShapedTypeList>(type($resource_operands),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ iree_compiler_cc_library(
"ScheduleExecution.cpp",
"SpecializeDispatches.cpp",
"SpecializeEncodings.cpp",
"SyncInitializers.cpp",
"VerifyAffinities.cpp",
"VerifyAsyncAccessRanges.cpp",
"VerifyLowerings.cpp",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ iree_cc_library(
"ScheduleExecution.cpp"
"SpecializeDispatches.cpp"
"SpecializeEncodings.cpp"
"SyncInitializers.cpp"
"VerifyAffinities.cpp"
"VerifyAsyncAccessRanges.cpp"
"VerifyLowerings.cpp"
Expand Down
46 changes: 33 additions & 13 deletions compiler/src/iree/compiler/Dialect/Stream/Transforms/Passes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,12 @@ using FunctionLikeNest =
MultiOpNest<func::FuncOp, IREE::Util::InitializerOp, IREE::Util::FuncOp>;

//===----------------------------------------------------------------------===//
// Utilities
// --iree-stream-cleanup-pipeline
//===----------------------------------------------------------------------===//

static void addCleanupPatterns(OpPassManager &passManager) {
static void buildStreamCleanupPassPipeline(
OpPassManager &passManager,
const IREE::Stream::TransformOptions &transformOptions) {
FunctionLikeNest(passManager)
// Standard MLIR cleanup.
.addPass(mlir::createCanonicalizerPass)
Expand Down Expand Up @@ -84,7 +86,7 @@ void buildStreamTensorPassPipeline(OpPassManager &passManager,

// Cleanup the program prior to outlining constants in case there is
// propagation or fusion that needs to happen first.
addCleanupPatterns(passManager);
buildStreamCleanupPassPipeline(passManager, transformOptions);

//----------------------------------------------------------------------------
// Conversion
Expand Down Expand Up @@ -114,7 +116,7 @@ void buildStreamTensorPassPipeline(OpPassManager &passManager,
passManager.addPass(mlir::createInlinerPass());

// Cleanup globals that were created during conversion.
addCleanupPatterns(passManager);
buildStreamCleanupPassPipeline(passManager, transformOptions);

// Bring all initializers together so that we can schedule them.
passManager.addPass(IREE::Util::createCombineInitializersPass());
Expand Down Expand Up @@ -160,7 +162,7 @@ void buildStreamAsyncPassPipeline(OpPassManager &passManager,
passManager.addNestedPass<IREE::Stream::ExecutableOp>(
IREE::Stream::createEncodeDeviceTensorsPass());

addCleanupPatterns(passManager);
buildStreamCleanupPassPipeline(passManager, transformOptions);

// Everything must now be in stream.async.* form but we don't yet have
// lifetime assigned.
Expand All @@ -186,7 +188,7 @@ void buildStreamAsyncPassPipeline(OpPassManager &passManager,
// change and it makes the IR cleaner.
passManager.addPass(IREE::Stream::createRefineUsagePass());

addCleanupPatterns(passManager);
buildStreamCleanupPassPipeline(passManager, transformOptions);

// Verify all stream.async.* op access ranges that we can by taking advantage
// of statically available information or that which we can infer from data
Expand All @@ -207,6 +209,13 @@ void buildStreamAsyncPassPipeline(OpPassManager &passManager,
// Group concurrently executable work into waves.
.addPass(IREE::Stream::createScheduleConcurrencyPass);

// When synchronous initialization is requested we need to separate any work
// behind a timepoint in the initializer from the consumers of that timepoint.
if (transformOptions.initializationMode ==
IREE::Stream::InitializationMode::Synchronous) {
passManager.addPass(IREE::Stream::createSyncInitializersPass());
}

// Materialize timepoints across the entire module. This simplifies scheduling
// of the timeline as we can shake the IR and see what timepoints we still
// have left.
Expand All @@ -217,7 +226,7 @@ void buildStreamAsyncPassPipeline(OpPassManager &passManager,
// for partitioning/placement before turning them into opaque dispatches.
passManager.addPass(IREE::Stream::createMaterializeBuiltinsPass());

addCleanupPatterns(passManager);
buildStreamCleanupPassPipeline(passManager, transformOptions);

// Everything must now be in stream.async.* form.
passManager.addPass(IREE::Stream::createVerifyLoweringToAsyncPass());
Expand Down Expand Up @@ -245,13 +254,17 @@ void buildStreamCmdPassPipeline(OpPassManager &passManager,
// Layout packed slices to emit the arithmetic required for all resource
// offsets. This enables us to propagate the subviews across the program
// below.
.addPass(IREE::Stream::createLayoutSlicesPass);
.addPass(IREE::Stream::createLayoutSlicesPass)

// Apply canonicalization patterns to clean up subview ops prior to
// propagating subranges.
.addPass(mlir::createCanonicalizerPass);

// Propagate subviews throughout the program to unify resource storage access.
// After propagation many resource SSA values can be deduped or folded by the
// cleanup patterns.
passManager.addPass(IREE::Util::createPropagateSubrangesPass());
addCleanupPatterns(passManager);
buildStreamCleanupPassPipeline(passManager, transformOptions);

// TODO(benvanik): outline streams (ala dispatch regions). Note that we may
// want to do this earlier to enable better deduplication but that makes the
Expand All @@ -270,7 +283,7 @@ void buildStreamOptimizationPassPipeline(
OpPassManager &passManager, const TransformOptions &transformOptions) {
// Forming streams involves a fair amount of subgraph stitching, which can
// cause duplication. Run CSE to collapse.
addCleanupPatterns(passManager);
buildStreamCleanupPassPipeline(passManager, transformOptions);

// If any scf ops crept in we get rid of them here. We should be able to
// support them all the way through the stream dialect but some passes are not
Expand All @@ -290,7 +303,7 @@ void buildStreamOptimizationPassPipeline(
OpPassManager ipoPipeline(mlir::ModuleOp::getOperationName());

// IPO and other cleanups.
addCleanupPatterns(ipoPipeline);
buildStreamCleanupPassPipeline(ipoPipeline, transformOptions);

// TODO(#9747): elide timepoints that are know-reached due to host
// synchronization via stream.timepoint.await.
Expand Down Expand Up @@ -333,7 +346,7 @@ void buildStreamOptimizationPassPipeline(

// Folding operands requires that canonicalization/CSE folds the inputs that
// we check for.
addCleanupPatterns(passManager);
buildStreamCleanupPassPipeline(passManager, transformOptions);
passManager.addPass(IREE::Stream::createFoldUniformOperandsPass());

// Only want to specialize after we've added all the operands we need above.
Expand Down Expand Up @@ -383,7 +396,7 @@ void buildStreamTransformPassPipeline(
//----------------------------------------------------------------------------

// Final cleanup after we optimize dispatches and fuse operands and bindings.
addCleanupPatterns(passManager);
buildStreamCleanupPassPipeline(passManager, transformOptions);

// Symbol DCE any remaining variables/functions that are now no longer
// required.
Expand All @@ -404,6 +417,13 @@ void registerStreamPasses() {
registerPasses();

// Pipelines.
PassPipelineRegistration<TransformOptions> cleanupPassPipeline(
"iree-stream-cleanup-pipeline",
"Runs the cleanup passes that are performed between stages of the full "
"stream pipeline.",
[](OpPassManager &passManager, const TransformOptions &transformOptions) {
buildStreamCleanupPassPipeline(passManager, transformOptions);
});
PassPipelineRegistration<TransformOptions> tensorPassPipeline(
"iree-stream-tensor-transformation-pipeline",
"Lowers source dialects into stream.tensor.* IR.",
Expand Down
33 changes: 31 additions & 2 deletions compiler/src/iree/compiler/Dialect/Stream/Transforms/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,22 @@ namespace mlir::iree_compiler::IREE::Stream {
// Pipelines
//===----------------------------------------------------------------------===//

// TODO(benvanik): find a way to share this with IREEVM.h w/o circular deps.
// TODO(benvanik): find a way to share option enums with the top-level Options.h
// w/o circular deps.

// Defines the behavior of initialization.
enum class InitializationMode {
// Synchronously initialize all parameters and globals prior to returning
// from the module initializer.
Synchronous = 0,
// Asynchronously initialize all parameters and globals and return
// immediately from the module initializer without waiting for them to
// complete. Subsequent invocations will queue waiting for any dependencies
// they have on the initialized values.
Asynchronous = 1,
};

// TODO(benvanik): find a way to share this with Options.h w/o circular deps.
// Defines the output format of a dump pass.
enum class DumpOutputFormat {
// Dumping disabled.
Expand All @@ -40,7 +55,21 @@ enum class DumpOutputFormat {
};

struct TransformOptions : public PassPipelineOptions<TransformOptions> {
// TODO(benvanik): options for async/sync overrides.
Option<InitializationMode> initializationMode{
*this,
"initialization-mode",
llvm::cl::desc(
"Specifies the initialization mode for parameters and globals."),
llvm::cl::init(InitializationMode::Synchronous),
llvm::cl::values(
clEnumValN(InitializationMode::Synchronous, "sync",
"Synchronously initialize all parameters and globals "
"prior to returning from the module initializer."),
clEnumValN(InitializationMode::Asynchronous, "async",
"Asynchronously initialize all parameters and globals and "
"return immediately from the module initializer without "
"waiting for them to complete.")),
};

Option<bool> optimizeBindings{
*this,
Expand Down
Loading

0 comments on commit bbe7f5c

Please sign in to comment.