Skip to content

Commit

Permalink
[ConstEval] Teach global hoisting to build a dot graph of its analysis (
Browse files Browse the repository at this point in the history
  • Loading branch information
Groverkss authored Sep 30, 2023
1 parent c753168 commit 83df8c4
Show file tree
Hide file tree
Showing 3 changed files with 156 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,15 @@
#include "iree/compiler/Dialect/Util/IR/UtilOps.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/GraphWriter.h"
#include "mlir/Dialect/Arith/IR/Arith.h"

#define DEBUG_TYPE "iree-constexpr"

using llvm::dbgs;

using namespace mlir::iree_compiler::IREE::Util;

namespace mlir {
namespace iree_compiler {
namespace IREE {
Expand Down Expand Up @@ -417,7 +420,57 @@ void ConstExprHoistingPolicy::makeDecision(
decision->enableHoist();
}

void ConstExprHoistingPolicy::printDotGraph(raw_ostream &os) const {
WriteGraph(os, this);
}

void ConstExprHoistingPolicy::dumpDotGraph() const {
printDotGraph(llvm::errs());
}

} // namespace Util
} // namespace IREE
} // namespace iree_compiler
} // namespace mlir

namespace llvm {
template <>
struct DOTGraphTraits<const ConstExprHoistingPolicy *>
: public DefaultDOTGraphTraits {
explicit DOTGraphTraits(bool isSimple = false)
: DefaultDOTGraphTraits(isSimple) {}

std::string getNodeLabel(const ConstExprAnalysis::ConstValueInfo *Node,
const ConstExprHoistingPolicy *g) {
std::string label;
llvm::raw_string_ostream os(label);
os << Node->constValue.getType();
return label;
}

static bool isNodeHidden(const ConstExprAnalysis::ConstValueInfo *Node,
const ConstExprHoistingPolicy *g) {
// Only display nodes that the analysis has determined to be const-expr.
return !Node->isConstExpr();
}

static std::string
getNodeAttributes(const ConstExprAnalysis::ConstValueInfo *Node,
const ConstExprHoistingPolicy *g) {
// Roots are colored red.
if (Node->isRoot)
return "fillcolor=red,style=filled";

// Hoisted values are colored green.
ConstExprHoistingPolicy::Outcome outcome = g->getOutcome(Node);
if (outcome == ConstExprHoistingPolicy::Outcome::ENABLE_HOIST)
return "fillcolor=green,style=filled";

return "";
}

static void
addCustomGraphFeatures(const ConstExprHoistingPolicy *g,
GraphWriter<const ConstExprHoistingPolicy *> &GW) {}
};
}; // namespace llvm
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
#include <vector>

#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/GraphTraits.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir/IR/Operation.h"
Expand Down Expand Up @@ -127,6 +129,28 @@ class ConstExprAnalysis {
}
};

// Define an iterator over the second value of constInfoMap.
using ConstValueMapT = llvm::DenseMap<Value, ConstValueInfo *>;
class ConstValueIterator final
: public llvm::mapped_iterator<
ConstValueMapT::const_iterator,
ConstValueInfo *(*)(const ConstValueMapT::value_type &)> {

static ConstValueInfo *unwrap(const ConstValueMapT::value_type &value) {
return value.second;
}

public:
ConstValueIterator(ConstValueMapT::const_iterator it)
: llvm::mapped_iterator<
ConstValueMapT::const_iterator,
ConstValueInfo *(*)(const ConstValueMapT::value_type &)>(
it, &unwrap) {}
};

ConstValueIterator begin() const { return constInfoMap.begin(); }
ConstValueIterator end() const { return constInfoMap.end(); }

private:
// Expands the frontier to include all results of a given op in an UNKNOWN
// state. This also checks that all of its operands are known, adding
Expand All @@ -146,6 +170,8 @@ class ConstExprAnalysis {
// Map of analyzed value to corresponding info struct.
llvm::DenseMap<Value, ConstValueInfo *> constInfoMap;

// Define an iterator over std::unique_ptr<T> as a pointer range on T*.

// Allocated ConstValueInfo structs (to preserve pointer stability).
llvm::SmallVector<std::unique_ptr<ConstValueInfo>> allocedConstInfos;

Expand Down Expand Up @@ -188,21 +214,30 @@ class ConstExprHoistingPolicy {
Outcome outcome = UNDECIDED;
};

void printDotGraph(raw_ostream &os) const;
void dumpDotGraph() const;

const ConstExprAnalysis &getAnalysis() const { return analysis; }

ConstExprHoistingPolicy(const ConstExprAnalysis &analysis);
void initialize();
Decision *getDecision(const ConstExprAnalysis::ConstValueInfo *info) {
return &decisions[info];
}

Outcome getOutcome(const ConstExprAnalysis::ConstValueInfo *info) const {
return decisions.lookup(info).getOutcome();
}

private:
// At initialization time, makes any fixed decisions. This hook can only
// make decisions that do not depend on any const-exprs outside of what is
// passed.
void makeInvariantDecision(const ConstExprAnalysis::ConstValueInfo *info,
Decision *decision);
// Makes a decision that depends on producers and consumers of a value. This
// may be called repeatedly until convergence. The implementation should call
// decision.disableHoist() or decision.enableHoist() if it can reach a
// may be called repeatedly until convergence. The implementation should
// call decision.disableHoist() or decision.enableHoist() if it can reach a
// decision.
void makeDecision(const ConstExprAnalysis::ConstValueInfo *info,
Decision *decision);
Expand All @@ -225,4 +260,50 @@ inline raw_ostream &operator<<(raw_ostream &os,
} // namespace iree_compiler
} // namespace mlir

namespace llvm {
template <>
struct GraphTraits<
mlir::iree_compiler::IREE::Util::ConstExprAnalysis::ConstValueInfo *> {
using NodeRef =
mlir::iree_compiler::IREE::Util::ConstExprAnalysis::ConstValueInfo *;
using ChildIteratorType = SmallPtrSetImpl<NodeRef>::iterator;

static NodeRef getEntryNode(NodeRef info) { return info; }

static ChildIteratorType child_begin(NodeRef info) {
return info->consumers.begin();
}

static ChildIteratorType child_end(NodeRef info) {
return info->consumers.end();
}
};

template <>
struct GraphTraits<
const mlir::iree_compiler::IREE::Util::ConstExprHoistingPolicy *>
: public GraphTraits<mlir::iree_compiler::IREE::Util::ConstExprAnalysis::
ConstValueInfo *> {

using nodes_iterator =
mlir::iree_compiler::IREE::Util::ConstExprAnalysis::ConstValueIterator;

static NodeRef getEntryNode(
const mlir::iree_compiler::IREE::Util::ConstExprHoistingPolicy *graph) {
return *graph->getAnalysis().begin();
}

static nodes_iterator nodes_begin(
const mlir::iree_compiler::IREE::Util::ConstExprHoistingPolicy *graph) {
return graph->getAnalysis().begin();
}

static nodes_iterator nodes_end(
const mlir::iree_compiler::IREE::Util::ConstExprHoistingPolicy *graph) {
return graph->getAnalysis().end();
}
};

} // namespace llvm

#endif // IREE_COMPILER_DIALECT_IREE_UTIL_ANALYSIS_CONSTANT_CONST_EXPR_H_
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,13 @@ namespace IREE {
namespace Util {
namespace {

static llvm::cl::opt<std::string> clPrintDotGraphToFile(
"iree-util-hoist-into-globals-print-constexpr-dotgraph-to",
llvm::cl::desc(
"Prints a dot graph representing the const-expr analysis. The red "
"nodes represent roots and the green nodes represent hoisted values."),
llvm::cl::value_desc("filename"));

// Maps an original value in the program to the symbol name of a global.
using HoistedValueMap = llvm::DenseMap<Value, GlobalOp>;

Expand All @@ -48,6 +55,19 @@ class HoistIntoGlobalsPass : public HoistIntoGlobalsBase<HoistIntoGlobalsPass> {
ConstExprHoistingPolicy policy(constExprs);
policy.initialize();

// Print analysis dot graph if requested.
if (!clPrintDotGraphToFile.empty()) {
std::error_code ec;
llvm::raw_fd_ostream file(clPrintDotGraphToFile, ec);
if (ec) {
getOperation().emitError()
<< "failed to open file for printing dot graph: " << ec.message();
return signalPassFailure();
}
policy.printDotGraph(file);
file.close();
}

// Maps original values to newly materialized values.
HoistedValueMap hoistedMap;

Expand Down

0 comments on commit 83df8c4

Please sign in to comment.