Skip to content

Commit

Permalink
Improving global folding and IPO for immutable globals. (iree-org#18066)
Browse files Browse the repository at this point in the history
Parameters and constants end up as globals that previously were not
getting marked as immutable unless the globals were default initialized.
Now the GlobalTable tracks whether a global is exclusively stored within
initializers (or functions only called from initializers) in order to
mark them as immutable. IPO was updated to support propagating uniform
immutable global loads across call edges as if they were constants (as
they effectively are just constants stored on the global scope).

Required for iree-org#17875 (to avoid treating constants/parameters as dynamic
binding table values).
  • Loading branch information
benvanik authored Jul 31, 2024
1 parent d8d1407 commit e792c32
Show file tree
Hide file tree
Showing 8 changed files with 231 additions and 19 deletions.
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
// RUN: iree-opt --split-input-file --iree-flow-transformation-pipeline --iree-flow-export-benchmark-funcs --verify-diagnostics %s | FileCheck %s
// RUN: iree-opt --split-input-file --iree-flow-export-benchmark-funcs-pass --verify-diagnostics %s | FileCheck %s

// Basic usage from the `--iree-native-bindings-support` flag.

// CHECK-LABEL: func private @simpleMul
util.func public @simpleMul(%arg0: !hal.buffer_view, %arg1: !hal.buffer_view) -> !hal.buffer_view attributes {iree.abi.stub, iree.module.export} {
util.func public @simpleMul(%arg0: !hal.buffer_view, %arg1: !hal.buffer_view) -> !hal.buffer_view {
%0 = hal.tensor.import %arg0 : !hal.buffer_view -> tensor<4xf32>
%1 = hal.tensor.import %arg1 : !hal.buffer_view -> tensor<4xf32>
%2 = arith.mulf %0, %1 : tensor<4xf32>
Expand Down Expand Up @@ -41,8 +41,8 @@ util.func public @while(%start: i32, %bound: i32) -> i32 {
// CHECK: util.global private @[[GLOBAL_ARG1:.+]] {{{.+}}} = 0 : i32

// CHECK: util.func public @while_benchmark()
// CHECK-DAG: %[[ARG0:.+]] = util.global.load immutable @[[GLOBAL_ARG0]] : i32
// CHECK-DAG: %[[ARG1:.+]] = util.global.load immutable @[[GLOBAL_ARG1]] : i32
// CHECK-DAG: %[[ARG0:.+]] = util.global.load @[[GLOBAL_ARG0]] : i32
// CHECK-DAG: %[[ARG1:.+]] = util.global.load @[[GLOBAL_ARG1]] : i32
// CHECK: %[[RET0:.+]] = util.call @while(%[[ARG0]], %[[ARG1]])
// CHECK: util.optimization_barrier %[[RET0]] : i32
// CHECK: util.return
Expand Down
61 changes: 60 additions & 1 deletion compiler/src/iree/compiler/Dialect/Util/Analysis/GlobalTable.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,57 @@

#include "iree/compiler/Dialect/Util/Analysis/GlobalTable.h"

#include "mlir/Analysis/CallGraph.h"

namespace mlir::iree_compiler::IREE::Util {

// Returns a set of all top-level callable ops that are externally reachable.
// Callables only reachable from initializers are excluded.
static DenseSet<Operation *>
calculateExternallyReachableOps(ModuleOp moduleOp) {
DenseSet<Operation *> externallyReachableOps;

// Expensive; we want to avoid this unless the call graph changes.
CallGraph callGraph(moduleOp);

SetVector<CallGraphNode *> worklist;
worklist.insert(callGraph.begin(), callGraph.end());
while (!worklist.empty()) {
auto *node = worklist.pop_back_val();
if (node->isExternal()) {
// Skip declarations.
continue;
}
auto *callableOp = node->getCallableRegion()->getParentOp();
if (isa<IREE::Util::InitializerOpInterface>(callableOp)) {
// Initializers are never externally reachable.
continue;
}
bool isExternallyReachable = externallyReachableOps.contains(callableOp);
if (auto funcOp = dyn_cast<FunctionOpInterface>(callableOp)) {
// Public functions exported on the module are externally reachable.
isExternallyReachable |= funcOp.isPublic();
}
if (isExternallyReachable) {
// Insert into the set of reachable ops and also any outgoing calls.
// Queue up the edges in the worklist for further processing.
externallyReachableOps.insert(callableOp);
for (auto outgoingEdge : *node) {
auto *calleeNode = outgoingEdge.getTarget();
if (!calleeNode->isExternal()) {
externallyReachableOps.insert(
calleeNode->getCallableRegion()->getParentOp());
worklist.insert(outgoingEdge.getTarget());
}
}
}
}

return externallyReachableOps;
}

GlobalTable::GlobalTable(mlir::ModuleOp moduleOp) : moduleOp(moduleOp) {
rebuild();
externallyReachableOps = calculateExternallyReachableOps(moduleOp);
}

void GlobalTable::rebuild() {
Expand Down Expand Up @@ -46,6 +93,18 @@ void GlobalTable::rebuild() {
}
}
}

for (auto &[globalName, global] : globalMap) {
bool anyNonInitializerStores = false;
for (auto storeOp : global.storeOps) {
auto callableOp = storeOp->getParentOfType<CallableOpInterface>();
if (externallyReachableOps.contains(callableOp)) {
anyNonInitializerStores = true;
break;
}
}
global.onlyInitialized = !anyNonInitializerStores;
}
}

Global &GlobalTable::lookup(StringRef globalName) {
Expand Down
19 changes: 17 additions & 2 deletions compiler/src/iree/compiler/Dialect/Util/Analysis/GlobalTable.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,10 @@ struct Global {
// currently have any input programs that require doing so.
bool isIndirect = false;

// True if all stores to the global are performed within initializers or calls
// only reachable from initializers.
bool onlyInitialized = false;

// All util.global.load ops referencing the global.
SmallVector<IREE::Util::GlobalLoadOpInterface> loadOps;
// All util.global.store ops referencing the global.
Expand Down Expand Up @@ -81,12 +85,20 @@ enum class GlobalAction {
// A constructed table of analyzed globals in a module with some utilities for
// manipulating them. This is designed for simple uses and more advanced
// analysis should be performed with an Explorer or DFX.
//
// The global table is not built on creation and `rebuild` must be called before
// querying it.
struct GlobalTable {
GlobalTable() = delete;
explicit GlobalTable(mlir::ModuleOp moduleOp);

MLIRContext *getContext() { return moduleOp.getContext(); }

// Rebuilds the global table.
// Must be called if the table is to be used after any globals or operations
// on globals have changed.
void rebuild();

// Total number of globals in the module.
size_t size() const { return globalOrder.size(); }

Expand Down Expand Up @@ -114,10 +126,13 @@ struct GlobalTable {
void eraseGlobal(StringRef globalName);

private:
void rebuild();

// Module under analysis.
mlir::ModuleOp moduleOp;

// Top-level callables that are externally reachable.
// Excludes initializers or any callable only reachable from initializers.
DenseSet<Operation *> externallyReachableOps;

// All globals in the order they are declared by symbol name.
SmallVector<StringRef> globalOrder;
// A map of global symbol names to analysis results.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ static bool updateGlobalImmutability(GlobalTable &globalTable) {
return globalTable.forEach([&](Global &global) {
if (!global.isCandidate()) {
return GlobalAction::PRESERVE;
} else if (!global.storeOps.empty()) {
} else if (!global.storeOps.empty() && !global.onlyInitialized) {
return GlobalAction::PRESERVE;
}
bool didChangeAny = global.op.isGlobalMutable() != false;
Expand Down Expand Up @@ -365,28 +365,29 @@ class FoldGlobalsPass : public FoldGlobalsBase<FoldGlobalsPass> {
void runOnOperation() override {
auto *context = &getContext();
RewritePatternSet patterns(context);

for (auto *dialect : context->getLoadedDialects()) {
dialect->getCanonicalizationPatterns(patterns);
}
for (auto op : context->getRegisteredOperations()) {
op.getCanonicalizationPatterns(patterns, context);
}

FrozenRewritePatternSet frozenPatterns(std::move(patterns));

auto moduleOp = getOperation();
beforeFoldingGlobals =
count(moduleOp.getOps<IREE::Util::GlobalOpInterface>());
GlobalTable globalTable(moduleOp);
beforeFoldingGlobals = globalTable.size();
for (int i = 0; i < 10; ++i) {
// TODO(benvanik): determine if we need this expensive folding.
if (failed(applyPatternsAndFoldGreedily(moduleOp, frozenPatterns))) {
signalPassFailure();
return;
}

GlobalTable globalTable(moduleOp);
bool didChange = false;

// Rebuild the global table after potential pattern changes.
globalTable.rebuild();

LLVM_DEBUG(llvm::dbgs() << "==== inlineConstantGlobalStores ====\n");
if (inlineConstantGlobalStores(globalTable)) {
LLVM_DEBUG(moduleOp.dump());
Expand Down Expand Up @@ -424,6 +425,7 @@ class FoldGlobalsPass : public FoldGlobalsBase<FoldGlobalsPass> {
}

if (!didChange) {
// No changes; complete fixed-point iteration.
break;
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ class FuseGlobalsPass : public FuseGlobalsBase<FuseGlobalsPass> {
auto moduleOp = getOperation();

GlobalTable globalTable(moduleOp);
globalTable.rebuild();

// Build a map of global symbol to a bitvector indicating which globals are
// stored with the same values in all instances.
Expand Down
61 changes: 55 additions & 6 deletions compiler/src/iree/compiler/Dialect/Util/Transforms/IPO.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ struct FuncAnalysis {
// Which args are uniform from all call sites.
BitVector callerUniformArgs;
// Values for each arg if they are uniformly constant at all call sites.
// May be any constant attribute or an immutable global symbol ref.
SmallVector<LocAttr> callerUniformArgValues;
// Uniform call operand index -> deduplicated index.
// Base/non-duplicated values will be identity.
Expand All @@ -69,6 +70,7 @@ struct FuncAnalysis {
// Which results are uniform from all return sites in the function.
BitVector calleeUniformResults;
// Values for each result if they are uniformly constant at all return sites.
// May be any constant attribute or an immutable global symbol ref.
SmallVector<LocAttr> calleeUniformResultValues;
// Uniform callee return operand index -> deduplicated index.
// Base/non-duplicated values will be identity.
Expand All @@ -94,9 +96,13 @@ struct FuncAnalysis {
os << "dupe(%arg" << callerUniformArgDupeMap[i] << ") ";
}
os << argTypes[i] << " ";
if (callerUniformArgValues[i]) {
os << "constant = ";
callerUniformArgValues[i].attr.print(os);
if (auto constant = callerUniformArgValues[i]) {
if (isa<SymbolRefAttr>(constant.attr)) {
os << "immutable global = ";
} else {
os << "constant = ";
}
constant.attr.print(os);
}
os << "\n";
}
Expand All @@ -113,9 +119,13 @@ struct FuncAnalysis {
os << "pass(%arg" << passthroughResultArgs[i] << ") ";
}
os << resultTypes[i] << " ";
if (calleeUniformResultValues[i]) {
os << "constant = ";
calleeUniformResultValues[i].attr.print(os);
if (auto constant = calleeUniformResultValues[i]) {
if (isa<SymbolRefAttr>(constant.attr)) {
os << "immutable global = ";
} else {
os << "constant = ";
}
constant.attr.print(os);
}
os << "\n";
}
Expand All @@ -128,6 +138,17 @@ struct FuncAnalysis {
}
};

// Returns a global symbol ref if the value is loaded from an immutable global.
static SymbolRefAttr matchImmutableGlobalLoad(Value value) {
if (auto loadOp = dyn_cast_if_present<IREE::Util::GlobalLoadOpInterface>(
value.getDefiningOp())) {
if (loadOp.isGlobalImmutable()) {
return loadOp.getGlobalAttr();
}
}
return {};
}

// Note that the analysis results may be incomplete.
static FuncAnalysis analyzeFuncOp(IREE::Util::FuncOp funcOp,
Explorer &explorer) {
Expand Down Expand Up @@ -184,6 +205,12 @@ static FuncAnalysis analyzeFuncOp(IREE::Util::FuncOp funcOp,
value.getType(),
constantValue,
};
} else if (auto globalRef = matchImmutableGlobalLoad(value)) {
analysis.calleeUniformResultValues[i] = {
value.getLoc(),
value.getType(),
globalRef,
};
}

// Check to see if the value returned is the same as previously seen.
Expand Down Expand Up @@ -245,6 +272,20 @@ static FuncAnalysis analyzeFuncOp(IREE::Util::FuncOp funcOp,
// Value constant has changed from prior calls: mark non-uniform.
analysis.callerUniformArgs.reset(i);
}
} else if (auto globalRef = matchImmutableGlobalLoad(value)) {
if (!seenArgAttrs[i]) {
// First call site with a constant or immutable global: stash so we
// can inline it if it's uniform.
seenArgAttrs[i] = globalRef;
analysis.callerUniformArgValues[i] = {
value.getLoc(),
value.getType(),
globalRef,
};
} else if (seenArgAttrs[i] != globalRef) {
// Value constant has changed from prior calls: mark non-uniform.
analysis.callerUniformArgs.reset(i);
}
} else {
// Check to see if the value is the same as previously seen.
// This will ensure that across calling functions we set non-uniform
Expand Down Expand Up @@ -367,6 +408,14 @@ static void replaceValueWithConstant(Value value, LocAttr constantValue,
OpBuilder &builder) {
Operation *op = nullptr;

// Immutable global loads are represented as constant symbol refs.
if (auto globalRef = dyn_cast<SymbolRefAttr>(constantValue.attr)) {
op = builder.create<IREE::Util::GlobalLoadOp>(
constantValue.loc.value(), constantValue.type,
globalRef.getLeafReference().getValue(),
/*is_immutable=*/true);
}

// Handle special builtin types that for some reason can't materialize
// themselves.
if (arith::ConstantOp::isBuildableWith(constantValue.attr,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,37 @@ util.func @foo(%arg0: index) -> (index, index, index) {

// -----

// CHECK: util.global private @immutable_initializer_local
util.global private mutable @immutable_initializer_local : index
// CHECK: util.global private @immutable_initializer_callee
util.global private mutable @immutable_initializer_callee : index
// CHECK: util.global private mutable @mutable : index
util.global private mutable @mutable : index
util.func private @generate_value() -> index
util.initializer {
%value = util.call @generate_value() : () -> index
util.global.store %value, @immutable_initializer_local : index
util.return
}
util.func @public_func() -> (index, index, index) {
util.call @public_callee() : () -> ()
// CHECK-DAG: %[[LOCAL:.+]] = util.global.load immutable @immutable_initializer_local
%0 = util.global.load @immutable_initializer_local : index
// CHECK-DAG: %[[CALLEE:.+]] = util.global.load immutable @immutable_initializer_callee
%1 = util.global.load @immutable_initializer_callee : index
// CHECK-DAG: %[[MUTABLE:.+]] = util.global.load @mutable
%2 = util.global.load @mutable : index
// CHECK: return %[[LOCAL]], %[[CALLEE]], %[[MUTABLE]]
util.return %0, %1, %2 : index, index, index
}
util.func private @public_callee() {
%value = util.call @generate_value() : () -> index
util.global.store %value, @mutable : index
util.return
}

// -----

// CHECK: util.global private mutable @used0 = 5 : index
util.global private mutable @used0 = 5 : index
// CHECK: util.global private mutable @used1 : index
Expand Down
Loading

0 comments on commit e792c32

Please sign in to comment.