From 215de20b010d9fd139168a61369fa7366e6e95d4 Mon Sep 17 00:00:00 2001 From: William Huang Date: Fri, 16 Jun 2017 05:23:14 -0400 Subject: [PATCH 1/2] Copy Propagation --- src/CMakeLists.txt | 1 + src/backend/backend_visitor.cpp | 5 +- src/intrusive_ptr_hash.h | 37 ++ src/ir.cpp | 23 +- src/ir.h | 2 +- src/ir_pattern_matching.h | 271 +++++++++ src/ir_printer.cpp | 6 +- src/ir_queries.h | 22 +- src/ir_rewriter.cpp | 45 +- src/ir_visitor.cpp | 5 +- src/lower/lower.cpp | 5 + src/program_analysis/CMakeLists.txt | 5 + src/program_analysis/cp.cpp | 641 ++++++++++++++++++++++ src/program_analysis/node_replacer.cpp | 73 +++ src/program_analysis/node_replacer.h | 30 + src/program_analysis/program_analysis.cpp | 10 + src/program_analysis/program_analysis.h | 21 + src/types.h | 14 + 18 files changed, 1160 insertions(+), 56 deletions(-) create mode 100644 src/intrusive_ptr_hash.h create mode 100644 src/ir_pattern_matching.h create mode 100644 src/program_analysis/CMakeLists.txt create mode 100644 src/program_analysis/cp.cpp create mode 100644 src/program_analysis/node_replacer.cpp create mode 100644 src/program_analysis/node_replacer.h create mode 100644 src/program_analysis/program_analysis.cpp create mode 100644 src/program_analysis/program_analysis.h diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index e3291d42..30f22668 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -27,6 +27,7 @@ add_subdirectory(backend) add_subdirectory(frontend) add_subdirectory(lower) add_subdirectory(visualizer) +add_subdirectory(program_analysis) add_definitions(${SIMIT_DEFINITIONS}) include_directories(${SIMIT_INCLUDE_DIRS}) diff --git a/src/backend/backend_visitor.cpp b/src/backend/backend_visitor.cpp index a0672865..bc8cf32d 100644 --- a/src/backend/backend_visitor.cpp +++ b/src/backend/backend_visitor.cpp @@ -34,9 +34,8 @@ void BackendVisitorBase::compile(const ir::GPUKernel& kernel) { #endif void BackendVisitorBase::compile(const ir::Block& op) { - op.first.accept(this); - if (op.rest.defined()) { - op.rest.accept(this); + for (Stmt stmt : op.stmts) { + stmt.accept(this); } } diff --git a/src/intrusive_ptr_hash.h b/src/intrusive_ptr_hash.h new file mode 100644 index 00000000..ba5dabff --- /dev/null +++ b/src/intrusive_ptr_hash.h @@ -0,0 +1,37 @@ +#ifndef SIMIT_INTRUSIVE_PTR_HASH_H +#define SIMIT_INTRUSIVE_PTR_HASH_H + +#include + +#include "intrusive_ptr.h" +#include "ir.h" + +namespace std { + +template +struct hash> { + size_t operator()(const simit::util::IntrusivePtr& x) const { + return (size_t) x.ptr; + } +}; + +#define HASH_INTRUSIVE_PTR(type) template <>\ + struct hash {\ + size_t operator()(const type& x) const {\ + return (size_t) x.ptr;\ + }\ + }; + +// SFINAE on detecting base class doesn't work with clang, need to wait for C++17. Workaround here + +HASH_INTRUSIVE_PTR(simit::ir::Var) +HASH_INTRUSIVE_PTR(simit::ir::IndexVar) +HASH_INTRUSIVE_PTR(simit::ir::Expr) +HASH_INTRUSIVE_PTR(simit::ir::Stmt) +HASH_INTRUSIVE_PTR(simit::ir::Func) + +#undef HASH_INTRUSIVE_PTR + +} + +#endif diff --git a/src/ir.cpp b/src/ir.cpp index 554b5ddd..1e52e4c4 100644 --- a/src/ir.cpp +++ b/src/ir.cpp @@ -837,17 +837,28 @@ Stmt Block::make(Stmt first, Stmt rest) { } Block *node = new Block; - node->first = first; - node->rest = rest; + if (isa(first)) { + const Block* block = to(first); + node->stmts = block->stmts; + } else { + node->stmts = {first}; + } + + if (rest.defined()) { + if (isa(rest)) { + const Block* block = to(rest); + node->stmts.insert(node->stmts.end(), block->stmts.begin(), block->stmts.end()); + } else { + node->stmts.push_back(rest); + } + } return node; } Stmt Block::make(std::vector stmts) { iassert(stmts.size() > 0) << "Empty block"; - Stmt node; - for (size_t i=stmts.size(); i>0; --i) { - node = Block::make(stmts[i-1], node); - } + Block *node = new Block; + node->stmts = stmts; return node; } diff --git a/src/ir.h b/src/ir.h index 238e8835..b432fe3b 100644 --- a/src/ir.h +++ b/src/ir.h @@ -439,7 +439,7 @@ struct Kernel : public StmtNode { }; struct Block : public StmtNode { - Stmt first, rest; + vector stmts; static Stmt make(Stmt first, Stmt rest); static Stmt make(std::vector stmts); void accept(IRVisitorStrict *v) const {v->visit((const Block*)this);} diff --git a/src/ir_pattern_matching.h b/src/ir_pattern_matching.h new file mode 100644 index 00000000..f0623f5a --- /dev/null +++ b/src/ir_pattern_matching.h @@ -0,0 +1,271 @@ +#ifndef SIMIT_IR_PATTERN_MATCHING_H +#define SIMIT_IR_PATTERN_MATCHING_H + +#include "ir.h" +#include +#include +#include + +namespace simit { +namespace ir { + +struct HandleCallbackArguments { + template + static bool call(const T* node, std::nullptr_t) { + return true; + } + +#if __cplusplus < 201700L + template ::type>::value>::type> + static bool call(const T* node, U&& func) { + return func(node); + } + + template ::type>::value>::type> + static bool call(const T* node, U&& func, int* dummy = 0) { + func(node); + return true; + } +#else + template ::type>::value>::type> + static bool call(const T* node, U&& func) { + return func(node); + } + + template ::type>::value>::type> + static bool call(const T* node, U&& func, int* dummy = 0) { + func(node); + return true; + } +#endif +}; + +template +class PatternMatch; + +template +class PatternMatch : public PatternMatch { +}; + +template <> +class PatternMatch { +public: + template + static bool match(ExprOrStmt expr, Callback&&) { + return true; + } +}; + +#define GENERATE_PATTERNMATCH_0(TYPE)\ + template <>\ + class PatternMatch {\ + public:\ + template \ + static bool match(ExprOrStmt expr, std::nullptr_t) {\ + return isa(expr);\ + }\ + \ + template \ + static bool match(ExprOrStmt expr, const std::tuple&) {\ + return isa(expr);\ + }\ + \ + template \ + static bool match(ExprOrStmt expr, const std::tuple& callbacks) {\ + if (isa(expr)) {\ + const TYPE* node = to(expr);\ + return HandleCallbackArguments::call(node, std::get<0>(callbacks));\ + }\ + return false;\ + }\ + }; + +#define GENERATE_PATTERNMATCH_1(TYPE, MEMBER1)\ + template \ + class PatternMatch {\ + public:\ + template \ + static bool match(ExprOrStmt expr, std::nullptr_t) {\ + return isa(expr);\ + }\ + \ + template \ + static bool match(ExprOrStmt expr, const std::tuple&) {\ + return isa(expr);\ + }\ + \ + template \ + static bool match(ExprOrStmt expr, const std::tuple& callbacks) {\ + if (isa(expr)) {\ + const TYPE* node = to(expr);\ + return HandleCallbackArguments::call(node, std::get<0>(callbacks)) &&\ + PatternMatch::match(node->MEMBER1, std::get<1>(callbacks));\ + }\ + return false;\ + }\ + }; + +#define GENERATE_PATTERNMATCH_2(TYPE, MEMBER1, MEMBER2)\ + template \ + class PatternMatch {\ + public:\ + template \ + static bool match(ExprOrStmt expr, std::nullptr_t) {\ + return isa(expr);\ + }\ + \ + template \ + static bool match(ExprOrStmt expr, const std::tuple&) {\ + return isa(expr);\ + }\ + \ + template \ + static bool match(ExprOrStmt expr, const std::tuple& callbacks) {\ + if (isa(expr)) {\ + const TYPE* node = to(expr);\ + return HandleCallbackArguments::call(node, std::get<0>(callbacks)) &&\ + PatternMatch::match(node->MEMBER1, std::get<1>(callbacks)) &&\ + PatternMatch::match(node->MEMBER2, std::get<2>(callbacks));\ + }\ + return false;\ + }\ + }; + +#define GENERATE_PATTERNMATCH_3(TYPE, MEMBER1, MEMBER2, MEMBER3)\ + template \ + class PatternMatch {\ + public:\ + template \ + static bool match(ExprOrStmt expr, std::nullptr_t) {\ + return isa(expr);\ + }\ + \ + template \ + static bool match(ExprOrStmt expr, const std::tuple&) {\ + return isa(expr);\ + }\ + \ + template \ + static bool match(ExprOrStmt expr, const std::tuple& callbacks) {\ + if (isa(expr)) {\ + const TYPE* node = to(expr);\ + return HandleCallbackArguments::call(node, std::get<0>(callbacks)) &&\ + PatternMatch::match(node->MEMBER1, std::get<1>(callbacks)) &&\ + PatternMatch::match(node->MEMBER2, std::get<2>(callbacks)) &&\ + PatternMatch::match(node->MEMBER3, std::get<3>(callbacks));\ + }\ + return false;\ + }\ + }; + +GENERATE_PATTERNMATCH_0(Literal) +GENERATE_PATTERNMATCH_0(VarExpr) +GENERATE_PATTERNMATCH_2(Load, buffer, index) +GENERATE_PATTERNMATCH_1(FieldRead, elementOrSet) +GENERATE_PATTERNMATCH_0(Length) +GENERATE_PATTERNMATCH_1(IndexRead, edgeSet) + +template +class PatternMatch { +public: + template + static bool match(ExprOrStmt expr, std::nullptr_t) { + return isa(expr); + } + + template + static bool match(ExprOrStmt expr, const std::tuple&) { + return isa(expr); + } + + template + static bool match(ExprOrStmt expr, const std::tuple& callbacks) { + if (isa(expr)) { + const UnaryExpr* node = to(expr); + return HandleCallbackArguments::call(node, std::get<0>(callbacks)) && + PatternMatch::match(node->a, std::get<1>(callbacks)); + } + return false; + } +}; + +template +class PatternMatch { +public: + template + static bool match(ExprOrStmt expr, std::nullptr_t) { + return isa(expr); + } + + template + static bool match(ExprOrStmt expr, const std::tuple&) { + return isa(expr); + } + + template + static bool match(ExprOrStmt expr, const std::tuple& callbacks) { + if (isa(expr)) { + const BinaryExpr* node = to(expr); + return HandleCallbackArguments::call(node, std::get<0>(callbacks)) && + PatternMatch::match(node->a, std::get<1>(callbacks)) && + PatternMatch::match(node->b, std::get<2>(callbacks)); + } + return false; + } +}; + +GENERATE_PATTERNMATCH_1(Neg, a) +GENERATE_PATTERNMATCH_2(Add, a, b) +GENERATE_PATTERNMATCH_2(Sub, a, b) +GENERATE_PATTERNMATCH_2(Mul, a, b) +GENERATE_PATTERNMATCH_2(Div, a, b) +GENERATE_PATTERNMATCH_2(Rem, a, b) + +GENERATE_PATTERNMATCH_1(Not, a) +GENERATE_PATTERNMATCH_2(Eq, a, b) +GENERATE_PATTERNMATCH_2(Ne, a, b) +GENERATE_PATTERNMATCH_2(Gt, a, b) +GENERATE_PATTERNMATCH_2(Lt, a, b) +GENERATE_PATTERNMATCH_2(Ge, a, b) +GENERATE_PATTERNMATCH_2(Le, a, b) +GENERATE_PATTERNMATCH_2(And, a, b) +GENERATE_PATTERNMATCH_2(Or, a, b) +GENERATE_PATTERNMATCH_2(Xor, a, b) + +GENERATE_PATTERNMATCH_0(VarDecl) +GENERATE_PATTERNMATCH_1(AssignStmt, value) +GENERATE_PATTERNMATCH_3(Store, buffer, index, value) +GENERATE_PATTERNMATCH_2(FieldWrite, elementOrSet, value) + +// call + +GENERATE_PATTERNMATCH_1(Scope, scopedStmt) +GENERATE_PATTERNMATCH_3(IfThenElse, condition, thenBody, elseBody) +GENERATE_PATTERNMATCH_3(ForRange, start, end, body) +GENERATE_PATTERNMATCH_1(For, body) +GENERATE_PATTERNMATCH_2(While, condition, body) +GENERATE_PATTERNMATCH_1(Kernel, body) + +// block + +GENERATE_PATTERNMATCH_1(Print, expr) +GENERATE_PATTERNMATCH_1(Comment,commentedStmt) +GENERATE_PATTERNMATCH_0(Pass) + +GENERATE_PATTERNMATCH_2(UnnamedTupleRead, tuple, index) +GENERATE_PATTERNMATCH_1(NamedTupleRead, tuple) +// setread +// tensorread +//GENERATE_PATTERNMATCH_(TensorWrite) +GENERATE_PATTERNMATCH_1(IndexedTensor, tensor) +GENERATE_PATTERNMATCH_1(IndexExpr, value) +//GENERATE_PATTERNMATCH_(Map) + +} +} + +#endif diff --git a/src/ir_printer.cpp b/src/ir_printer.cpp index c0d390fb..30ac74c8 100644 --- a/src/ir_printer.cpp +++ b/src/ir_printer.cpp @@ -468,10 +468,10 @@ void IRPrinter::visit(const Kernel *op) { } void IRPrinter::visit(const Block *op) { - print(op->first); - if (op->rest.defined()) { + print(op->stmts.front()); + for (size_t i = 1; i < op->stmts.size(); i++) { os << endl; - print(op->rest); + print(op->stmts[i]); } } diff --git a/src/ir_queries.h b/src/ir_queries.h index c151b624..2359d364 100644 --- a/src/ir_queries.h +++ b/src/ir_queries.h @@ -112,26 +112,24 @@ inline std::vector splitOnPredicate( } void visit(const Block* op) { - Stmt first, rest; + vector newStmts; if (visitFirst) { // Visit in first-to-last order - if (!visitDone) { - first = rewrite(op->first); - } - if (!visitDone) { - rest = rewrite(op->rest); + for (size_t i = 0; i < op->stmts.size(); i++) { + if (!visitDone) { + newStmts.push_back(rewrite(op->stmts[i])); + } } } else { // Visit in last-to-first order - if (!visitDone) { - rest = rewrite(op->rest); - } - if (!visitDone) { - first = rewrite(op->first); + for (size_t i = op->stmts.size(); i > 0; i--) { + if (!visitDone) { + newStmts.push_back(rewrite(op->stmts[i - 1])); + } } } - stmt = Block::make(first, rest); + stmt = Block::make(newStmts); } void visit(const IfThenElse *op) { diff --git a/src/ir_rewriter.cpp b/src/ir_rewriter.cpp index da519fe4..9b491e28 100644 --- a/src/ir_rewriter.cpp +++ b/src/ir_rewriter.cpp @@ -10,9 +10,6 @@ Expr IRRewriter::rewrite(Expr e) { e.accept(this); e = expr; } - else { - e = Expr(); - } expr = Expr(); stmt = Stmt(); func = Func(); @@ -28,9 +25,6 @@ Stmt IRRewriter::rewrite(Stmt s) { } s = stmt; } - else { - s = Stmt(); - } expr = Expr(); stmt = Stmt(); func = Func(); @@ -42,9 +36,6 @@ Func IRRewriter::rewrite(Func f) { f.accept(this); f = func; } - else { - f = Func(); - } expr = Expr(); stmt = Stmt(); func = Func(); @@ -285,7 +276,7 @@ void IRRewriter::visit(const ForRange *op) { Expr end = rewrite(op->end); Stmt spilledBounds = getSpilledStmts(); Stmt body = rewrite(op->body); - + if (body == op->body && start == op->start && end == op->end) { stmt = op; } @@ -311,7 +302,7 @@ void IRRewriter::visit(const While *op) { Expr condition = rewrite(op->condition); Stmt spilledCond = getSpilledStmts(); Stmt body = rewrite(op->body); - + if (condition == op->condition && body == op->body) { stmt = op; } @@ -337,22 +328,20 @@ void IRRewriter::visit(const Kernel *op) { } void IRRewriter::visit(const Block *op) { - Stmt first = rewrite(op->first); - Stmt rest = rewrite(op->rest); - if (first == op->first && rest == op->rest) { - stmt = op; + vector newStmts; + + for (Stmt stmt : op->stmts) { + newStmts.push_back(rewrite(stmt)); } - else { - if (first.defined() && rest.defined()) { - stmt = Block::make(first, rest); - } - else if (first.defined() && !rest.defined()) { - stmt = first; - } - else if (!first.defined() && rest.defined()) { - stmt = rest; - } - else { + + if (newStmts == op->stmts) { + stmt = op; + } else { + newStmts.resize(std::remove_if(newStmts.begin(), newStmts.end(), + [](Stmt s){return !s.defined();}) - newStmts.begin()); + if (newStmts.size()) { + stmt = Block::make(newStmts); + } else { stmt = Stmt(); } } @@ -499,7 +488,7 @@ void IRRewriter::visit(const Map *op) { if (op->through.defined()) { through = rewrite(op->through); } - + std::vector partial_actuals(op->partial_actuals.size()); bool actualsSame = true; for (size_t i=0; i < op->partial_actuals.size(); ++i) { @@ -509,7 +498,7 @@ void IRRewriter::visit(const Map *op) { } } - if (target == op->target && through == op->through && + if (target == op->target && through == op->through && neighborsSame && actualsSame) { stmt = op; } diff --git a/src/ir_visitor.cpp b/src/ir_visitor.cpp index 43d7d51e..b68026a9 100644 --- a/src/ir_visitor.cpp +++ b/src/ir_visitor.cpp @@ -168,9 +168,8 @@ void IRVisitor::visit(const Kernel *op) { } void IRVisitor::visit(const Block *op) { - op->first.accept(this); - if (op->rest.defined()) { - op->rest.accept(this); + for (Stmt stmt : op->stmts) { + stmt.accept(this); } } diff --git a/src/lower/lower.cpp b/src/lower/lower.cpp index f9420700..2556ee1a 100644 --- a/src/lower/lower.cpp +++ b/src/lower/lower.cpp @@ -22,6 +22,7 @@ #include "ir_transforms.h" #include "ir_printer.h" #include "path_expressions.h" +#include "program_analysis/program_analysis.h" #ifdef GPU #include "backend/gpu/gpu_backend.h" @@ -101,6 +102,10 @@ Func lower(Func func, std::ostream* os, bool time) { } #endif + // Program analysis + func = rewriteCallGraph(func, program_analysis::program_analysis); + printCallGraph("Program analysis", func, os); + // Inline function calls func = rewriteCallGraph(func, inlineCalls); printCallGraph("Inline Function Calls", func, os); diff --git a/src/program_analysis/CMakeLists.txt b/src/program_analysis/CMakeLists.txt new file mode 100644 index 00000000..03b6be22 --- /dev/null +++ b/src/program_analysis/CMakeLists.txt @@ -0,0 +1,5 @@ +file(GLOB SIMIT_HEADERS ${SIMIT_HEADERS} *.h) +file(GLOB SIMIT_SOURCES ${SIMIT_SOURCES} *.cpp) + +set(SIMIT_SOURCES ${SIMIT_SOURCES} PARENT_SCOPE) +set(SIMIT_HEADERS ${SIMIT_HEADERS} PARENT_SCOPE) diff --git a/src/program_analysis/cp.cpp b/src/program_analysis/cp.cpp new file mode 100644 index 00000000..e681b93f --- /dev/null +++ b/src/program_analysis/cp.cpp @@ -0,0 +1,641 @@ +#include "program_analysis.h" +#include "intrusive_ptr_hash.h" +#include "ir_pattern_matching.h" +#include "ir_builder.h" +#include "node_replacer.h" +#include "util/collections.h" + +#include +#include +#include + +using namespace std; +using namespace simit::util; +using namespace simit::ir; +using namespace simit::ir::program_analysis; + +typedef vector IndexVars; + +static bool exprEq(Expr a, Expr b) { + if (isa(a) && isa(b)) { + const VarExpr* ea = to(a); + const VarExpr* eb = to(b); + return ea->var == eb->var; + } else if (isa(a) && isa(b)) { + const Literal* ea = to(a); + const Literal* eb = to(b); + return *ea == *eb; + } + return false; +} + +namespace { +struct IndexVarComp { + bool operator()(const IndexVar& a, const IndexVar& b) const { + return a.ptr == b.ptr; + } +}; +} + +typedef unordered_map, IndexVarComp> Mapping; + +static bool isSurjective(const IndexVar& domain, const IndexVar& image, + Mapping& mapping, Mapping& /*dummy*/) { + iassert(!domain.isReductionVar()); + + if (domain.isFixed()) { + if (!IndexVarComp()(domain, image)) { + return false; + } + } else { + auto it = mapping.find(domain); + if (it != mapping.end()) { + if (!IndexVarComp()(it->second, image)) { + return false; + } + } else { + mapping[domain] = image; + } + } + return true; +} + +static bool isBijective(const IndexVar& domain, const IndexVar& image, + Mapping& mapping, Mapping& reverseMapping) { + return isSurjective(domain, image, mapping, mapping) && + isSurjective(image, domain, reverseMapping, reverseMapping); +} + +static bool checkIndicesMapping(const IndexVars& domain, const IndexVars& image, + bool(*projection)(const IndexVar&, const IndexVar&, Mapping&, Mapping&)) { + if (domain.size() != image.size()) { + return false; + } + + Mapping mapping, reverseMapping; + for (auto&& pair : ZipConstIterable(domain, image)) { + if (!projection(pair.first, pair.second, mapping, reverseMapping)) { + return false; + } + } + return true; +} + +static +bool checkIndicesMapping(const IndexVars& domain1, const IndexVars& domain2, + const IndexVars& image1, const IndexVars& image2, + bool(*projection)(const IndexVar&, const IndexVar&, Mapping&, Mapping&)) { + if (domain1.size() != image1.size() || domain2.size() != image2.size()) { + return false; + } + + Mapping mapping, reverseMapping; + for (auto&& pair : ZipConstIterable(domain1, image1)) { + if (!projection(pair.first, pair.second, mapping, reverseMapping)) { + return false; + } + } + for (auto&& pair : ZipConstIterable(domain2, image2)) { + if (!projection(pair.first, pair.second, mapping, reverseMapping)) { + return false; + } + } + + return true; +} + +static bool mapIndices(const IndexVars& domain1, const IndexVars& domain2, + const IndexVars& image1, IndexVars& image2, + bool(*projection)(const IndexVar&, const IndexVar&, Mapping&, Mapping&)) { + if (domain1.size() != image1.size()) { + return false; + } + + Mapping mapping, reverseMapping; + for (auto&& pair : ZipConstIterable(domain1, image1)) { + if (!projection(pair.first, pair.second, mapping, reverseMapping)) { + return false; + } + } + + for (const IndexVar& domain : domain2) { + iassert(!domain.isReductionVar()); + } + + image2.clear(); + for (const IndexVar& domain : domain2) { + if (domain.isFixed()) { + image2.push_back(domain); + } else { + IndexVar image; + auto it = mapping.find(domain); + if (it != mapping.end()) { + image = it->second; + } else { + image = domain; + } + image2.push_back(image); + } + } + return true; +} + +namespace { +class UseDefGraph { + struct Node { + Node* head; + vector tails; + Expr value; + IndexVars indices; + + Node(Expr value, const IndexVars& indices) + : head(nullptr), value(value), indices(indices) { + iassert(isa(value) || isa(value)); + } + + bool operator==(const Node& other) const { + return exprEq(this->value, other.value); + } + }; + + struct NodeIterator { + Node* node; + + NodeIterator(Node* node) : node(node) {} + + bool operator!=(const NodeIterator& other) const { + return node != other.node; + } + + void operator++() { + node = node->head; + } + + Node* operator*() { + return node; + } + + NodeIterator begin() { + return NodeIterator(node); + } + + NodeIterator end() { + return NodeIterator(nullptr); + } + }; + + struct NodeHash { + size_t operator() (const Node* const& node) const { + if (isa(node->value)) { + const VarExpr* e = to(node->value); + return (size_t) e->var.ptr; + } else { + iassert(isa(node->value)); + return 0; + } + } + + bool operator() (const Node* const& a, const Node* const& b) const { + return a == b || a && b && *a == *b; + } + }; + + unordered_multiset data; +public: + UseDefGraph() = default; + + UseDefGraph(const UseDefGraph& other) { + unordered_map oldNewMap; + for (const Node* oldNode : other.data) { + Node* newNode = new Node(oldNode->value, oldNode->indices); + oldNewMap[oldNode] = newNode; + data.insert(newNode); + } + + for (const Node* oldNode : other.data) { + Node* newNode = oldNewMap[oldNode]; + newNode->head = oldNewMap[oldNode->head]; + if (newNode->head) { + newNode->head->tails.push_back(newNode); + } + } + } + + ~UseDefGraph() { + for (Node* node : data) { + delete node; + } + } + + void swap(UseDefGraph& other) { + data.swap(other.data); + } + + void insert(Expr destVal, const IndexVars& destIndices, + Expr srcVal, const IndexVars& srcIndices) { + // Due to bug in self assignment (A = A'), if dest == src, then dest is + // invalidated regardless of indices, unless indices exactly match + iassert(isa(destVal)); + + if (exprEq(destVal, srcVal) && destIndices == srcIndices) { + return; + } + + // handle repeated assignment + Node dest1(destVal, destIndices); + for (auto it_pair = data.equal_range(&dest1); + it_pair.first != it_pair.second; ++it_pair.first) { + Node* destOriginal = *it_pair.first; + if (checkIndicesMapping(destOriginal->indices, destIndices, isSurjective)) { + for (Node* srcOriginal : NodeIterator(destOriginal)) { + if (exprEq(srcVal, srcOriginal->value) && + checkIndicesMapping(destOriginal->indices, srcOriginal->indices, + destIndices, srcIndices, isSurjective)) { + return; + } + } + } + } + + erase(destVal, destIndices); + Node src(srcVal, srcIndices); + Node* headCandidate = nullptr; + IndexVars mappedIndicesCandidate; + + for (auto it_pair = data.equal_range(&src); + it_pair.first != it_pair.second; ++it_pair.first) { + Node* head = *it_pair.first; + IndexVars mappedIndices; + if (mapIndices(srcIndices, destIndices, + head->indices, mappedIndices, isBijective)) { + iassert(!headCandidate); + headCandidate = head; + mappedIndicesCandidate = mappedIndices; + } + } + + Node* dest; + if (headCandidate) { + dest = new Node(destVal, mappedIndicesCandidate); + } else { + dest = new Node(destVal, destIndices); + headCandidate = new Node(srcVal, srcIndices); + data.insert(headCandidate); + } + data.insert(dest); + dest->head = headCandidate; + headCandidate->tails.push_back(dest); + } + + void erase(Expr val, const IndexVars& indices = {}) { + // Erase if var match && index var interfere + Node node(val, indices); + for (auto it_pair = data.equal_range(&node); + it_pair.first != it_pair.second; + it_pair.first = data.erase(it_pair.first)) { + const Node* value = *it_pair.first; + if (Node* head = value->head) { + head->tails.erase(find(head->tails.begin(), head->tails.end(), value)); + head->tails.insert(head->tails.end(), + value->tails.begin(), value->tails.end()); + } + for (Node* other : value->tails) { + other->head = value->head; + } + delete value; + } + } + + bool get(Expr val, const IndexVars& indices, + Expr& resultVar, IndexVars& resultIndices) { + Node node(val, indices); + for (auto it_pair = data.equal_range(&node); + it_pair.first != it_pair.second; ++it_pair.first) { + const Node* tail = *it_pair.first; + const Node* head = tail; + while (head->head) { + head = head->head; + } + // tail->indices contains more info than indices + if (!exprEq(val, head->value) && + mapIndices(tail->indices, head->indices, + indices, resultIndices, isSurjective)) { + resultVar = head->value; + return true; + } + } + return false; + } + + bool merge(const UseDefGraph& other) { + bool merged = true; + + for (Node* selfDest : data) { + bool notFound = true; + for (auto it_pair = other.data.equal_range(selfDest); + notFound && it_pair.first != it_pair.second; ++it_pair.first) { + const Node* otherDest = *it_pair.first; + if (checkIndicesMapping(selfDest->indices, otherDest->indices, + isBijective)) { + if (selfDest->head) { + NodeIterator selfSrcIterator(selfDest->head); + for (auto selfSrcBegin = selfSrcIterator.begin(), + selfSrcEnd = selfSrcIterator.end(); + notFound && selfSrcBegin != selfSrcEnd; ++selfSrcBegin) { + Node* selfSrc = *selfSrcBegin; + NodeIterator otherSrcIterator(otherDest->head); + for (auto otherSrcBegin = otherSrcIterator.begin(), + otherSrcEnd = otherSrcIterator.end(); + notFound && otherSrcBegin != otherSrcEnd; ++otherSrcBegin) { + Node* otherSrc = *otherSrcBegin; + if (exprEq(selfSrc->value, otherSrc->value) && + checkIndicesMapping(selfDest->indices, selfSrc->indices, + otherDest->indices, otherSrc->indices, isBijective)) { + notFound = false; + } + } + } + } else { + notFound = false; + } + } + } + + if (notFound) { + merged = false; + if (Node* selfSrc = selfDest->head) { + selfSrc->tails.erase(find(selfSrc->tails.begin(), + selfSrc->tails.end(), selfDest)); + selfDest->head = nullptr; + } + } + } + + for (auto selfIt = data.begin(); selfIt != data.end();) { + Node* selfNode = *selfIt; + if (!selfNode->head && !selfNode->tails.size()) { + delete selfNode; + selfIt = data.erase(selfIt); + merged = false; + } else { + ++selfIt; + } + } + + return merged; + } +}; +} + +static bool isReduction(const IndexVar& iv) { + return iv.isReductionVar(); +} + +static Expr getBuffer(Expr e) { + struct : public IRVisitor { + Expr result; + + virtual void visit(const VarExpr* e) { + result = e; + } + + virtual void visit(const Load* op) { op->buffer.accept(this); } + + virtual void visit(const FieldRead* op) { op->elementOrSet.accept(this); } + + virtual void visit(const IndexRead* op) { op->edgeSet.accept(this); } + + virtual void visit(const Store* op) { op->buffer.accept(this); } + + virtual void visit(const FieldWrite* op) { op->elementOrSet.accept(this); } + + virtual void visit(const UnnamedTupleRead* op) { op->tuple.accept(this); } + + virtual void visit(const NamedTupleRead* op) { op->tuple.accept(this); } + + virtual void visit(const SetRead* op) { op->set.accept(this); } + + virtual void visit(const TensorRead* op) { op->tensor.accept(this); } + + virtual void visit(const TensorWrite* op) { op->tensor.accept(this); } + } visitor; + e.accept(&visitor); + return visitor.result; +} + +Func CSE::rewrite(Func func) { + struct : public NodeReplacer { + UseDefGraph state; + + virtual void visit(const VarExpr* op) { + Expr mappedVar; + IndexVars mappedIndices; + if (state.get(op, {}, mappedVar, mappedIndices)) { + Expr replace = mappedVar; + if (mappedIndices.size()) { + replace = IndexedTensor::make(replace, mappedIndices); + } + exprReplacement = replace; + } + } + + virtual void visit(const IndexedTensor* op) { + Expr originalVar, mappedVar; + IndexVars originalIndices, mappedIndices; + + auto&& matchIT = [&](const IndexedTensor* e) { + originalIndices = e->indexVars; + }; + + auto&& matchVE = [&](const VarExpr* e) { + originalVar = e; + return e->type.isTensor(); + }; + + if (PatternMatch::match( + op, make_tuple(matchIT, make_tuple(matchVE))) && + state.get(originalVar, originalIndices, mappedVar, mappedIndices)) { + exprReplacement = IndexedTensor::make(mappedVar, mappedIndices); + } else { + NodeReplacer::visit(op); + } + } + + virtual void visit(const AssignStmt *op) { + NodeReplacer::visit(op); + + Expr src, dest; + IndexVars srcIndices, destIndices; + + auto&& matchAssign = [&](const AssignStmt* stmt) { + dest = VarExpr::make(stmt->var); + return stmt->cop == CompoundOperator::None; + }; + + auto&& matchIE = [&](const IndexExpr* e) { + destIndices = e->resultVars; + return none_of(e->resultVars.begin(), e->resultVars.end(), isReduction); + }; + + auto&& matchIT = [&](const IndexedTensor* e) { + srcIndices = e->indexVars; + return none_of(e->indexVars.begin(), e->indexVars.end(), isReduction); + }; + + auto&& matchVE = [&](const VarExpr* e) { + src = e; + return e->type.isTensor(); + }; + + auto&& matchLiteral = [&](const Literal* e) { + src = e; + return e->type.isTensor(); + }; + + if (PatternMatch::match + (op, make_tuple(matchAssign, make_tuple(matchIE, make_tuple(matchIT, + make_tuple(matchVE))))) || + PatternMatch::match + (op, make_tuple(matchAssign, make_tuple(matchIE, make_tuple(matchIT, + make_tuple(matchLiteral))))) || + PatternMatch::match(op, make_tuple(matchAssign, + make_tuple(matchVE))) || + PatternMatch::match(op, make_tuple(matchAssign, + make_tuple(matchLiteral)))) { + if (isa(src) && isScalar(src.type()) && + dest.type().toTensor()->order() > 0) { + if (isFixedSizeTensor(dest.type())) { + // Expand constant scalar to constant tensor + const Literal* srcLiteral = to(src); + const size_t n = dest.type().toTensor()->size(); + const size_t size = srcLiteral->size; + uint8_t* value = new uint8_t[n * size]; + for (size_t i = 0; i < n; i++) { + memcpy(value + i * size, srcLiteral->data, size); + } + src = Literal::make(dest.type(), value, n * size); + delete[] value; + + auto&& indexSets = dest.type().toTensor()->getOuterDimensions(); + IndexVarFactory ivf; + for (auto&& indexSet : indexSets) { + IndexVar iv = ivf.createIndexVar(IndexDomain(indexSet)); + srcIndices.push_back(iv); + destIndices.push_back(iv); + } + } else { + /// Limitation: can't do constant folding for variable sized tensor + state.erase(dest); + return; + } + } + state.insert(dest, destIndices, src, srcIndices); + } else { + state.erase(dest); + } + } + + virtual void visit(const IfThenElse *op) { + rewrite(op->condition); + + UseDefGraph copyState = state; + rewrite(op->thenBody); + state.swap(copyState); + if (op->elseBody.defined()) { + rewrite(op->elseBody); + } + state.merge(copyState); + } + + virtual void visit(const ForRange *op) { + rewrite(op->start); + rewrite(op->end); + + disable(); + state.erase(op->var); + + bool merged; + do { + UseDefGraph copyState = state; + rewrite(op->body); + state.erase(op->var); + merged = copyState.merge(state); + state.swap(copyState); + } while (!merged); + + enable(); + UseDefGraph copyState = state; + rewrite(op->body); + state.swap(copyState); + } + + virtual void visit(const For *op) { + disable(); + state.erase(op->var); + + bool merged; + do { + UseDefGraph copyState = state; + rewrite(op->body); + state.erase(op->var); + merged = copyState.merge(state); + state.swap(copyState); + } while (!merged); + + enable(); + UseDefGraph copyState = state; + rewrite(op->body); + state.swap(copyState); + } + + virtual void visit(const While *op) { + disable(); + + bool merged; + do { + UseDefGraph copyState = state; + rewrite(op->condition); + rewrite(op->body); + merged = copyState.merge(state); + state.swap(copyState); + } while (!merged); + + enable(); + rewrite(op->condition); + UseDefGraph copyState = state; + rewrite(op->body); + state.swap(copyState); + } + + virtual void visit(const CallStmt *op) { + NodeReplacer::visit(op); + for (Var v : op->results) { + state.erase(v); + } + } + + virtual void visit(const Store *op) { + NodeReplacer::visit(op); + state.erase(getBuffer(op->buffer)); + } + + virtual void visit(const FieldWrite *op) { + NodeReplacer::visit(op); + state.erase(getBuffer(op->elementOrSet)); + } + + virtual void visit(const TensorWrite *op) { + NodeReplacer::visit(op); + state.erase(getBuffer(op->tensor)); + } + + virtual void visit(const Map *op) { + NodeReplacer::visit(op); + for (Var v : op->vars) { + state.erase(v); + } + } + } visitor; + + return visitor.rewrite(func); +} diff --git a/src/program_analysis/node_replacer.cpp b/src/program_analysis/node_replacer.cpp new file mode 100644 index 00000000..ac9c313b --- /dev/null +++ b/src/program_analysis/node_replacer.cpp @@ -0,0 +1,73 @@ +#include "node_replacer.h" + +using namespace simit::ir; +using namespace simit::ir::program_analysis; + +Expr NodeReplacer::rewrite(Expr e) { + if (e.defined()) { + e.accept(this); + if (!disableLevel) { + if (exprReplacement.defined()) { + e = exprReplacement; + } else if (expr.defined()) { + e = expr; + } + } + } + exprReplacement = Expr(); + stmtReplacement = Stmt(); + funcReplacement = Func(); + expr = Expr(); + stmt = Stmt(); + func = Func(); + return e; +} + +Stmt NodeReplacer::rewrite(Stmt s) { + if (s.defined()) { + s.accept(this); + if (!disableLevel) { + if (stmtReplacement.defined()) { + s = stmtReplacement; + } else if (stmt.defined()) { + s = stmt; + } + } + } + exprReplacement = Expr(); + stmtReplacement = Stmt(); + funcReplacement = Func(); + expr = Expr(); + stmt = Stmt(); + func = Func(); + return s; +} + +Func NodeReplacer::rewrite(Func f) { + if (f.defined()) { + f.accept(this); + if (!disableLevel) { + if (funcReplacement.defined()) { + f = funcReplacement; + } else if (func.defined()) { + f = func; + } + } + } + exprReplacement = Expr(); + stmtReplacement = Stmt(); + funcReplacement = Func(); + expr = Expr(); + stmt = Stmt(); + func = Func(); + return f; +} + +void NodeReplacer::enable() { + iassert(disableLevel); + disableLevel--; +} + +void NodeReplacer::disable() { + disableLevel++; +} diff --git a/src/program_analysis/node_replacer.h b/src/program_analysis/node_replacer.h new file mode 100644 index 00000000..17e52f3d --- /dev/null +++ b/src/program_analysis/node_replacer.h @@ -0,0 +1,30 @@ +#ifndef SIMIT_IR_PROGRAM_ANALYSIS_NODE_REPLACER_H +#define SIMIT_IR_PROGRAM_ANALYSIS_NODE_REPLACER_H + +#include "ir_rewriter.h" + +namespace simit { +namespace ir { +namespace program_analysis { + +class NodeReplacer : public IRRewriter { + unsigned int disableLevel = 0; +public: + virtual Expr rewrite(Expr expr); + virtual Stmt rewrite(Stmt stmt); + virtual Func rewrite(Func func); + + void enable(); + void disable(); + +protected: + Expr exprReplacement; + Stmt stmtReplacement; + Func funcReplacement; +}; + +} +} +} + +#endif diff --git a/src/program_analysis/program_analysis.cpp b/src/program_analysis/program_analysis.cpp new file mode 100644 index 00000000..dab2e7ff --- /dev/null +++ b/src/program_analysis/program_analysis.cpp @@ -0,0 +1,10 @@ +#include "program_analysis.h" + +using namespace simit::ir; +using namespace simit::ir::program_analysis; + +Func simit::ir::program_analysis::program_analysis(Func func) { + func = CSE().rewrite(func); + + return func; +} diff --git a/src/program_analysis/program_analysis.h b/src/program_analysis/program_analysis.h new file mode 100644 index 00000000..62686322 --- /dev/null +++ b/src/program_analysis/program_analysis.h @@ -0,0 +1,21 @@ +#ifndef SIMIT_IR_PROGRAM_ANALYSIS_H +#define SIMIT_IR_PROGRAM_ANALYSIS_H + +#include "ir.h" + +namespace simit { +namespace ir { +namespace program_analysis { + +class CSE { +public: + simit::ir::Func rewrite(simit::ir::Func func); +}; + +simit::ir::Func program_analysis(simit::ir::Func func); + +} +} +} + +#endif diff --git a/src/types.h b/src/types.h index 4383fd8f..0df46ebb 100644 --- a/src/types.h +++ b/src/types.h @@ -424,6 +424,20 @@ inline bool isElementTensorType(Type type) { return !isSystemTensorType(type); } +inline bool isFixedSizeTensor(Type type) { + if (!type.isTensor()) { + return false; + } + for (auto&& dimension : type.toTensor()->getDimensions()) { + for (auto&& indexSet : dimension.getIndexSets()) { + if (indexSet.getKind() != IndexSet::Range) { + return false; + } + } + } + return true; +} + bool operator==(const Type&, const Type&); bool operator!=(const Type&, const Type&); From 72ebec764447d0c64a8f9d0198f211438930bf0c Mon Sep 17 00:00:00 2001 From: William Huang Date: Fri, 16 Jun 2017 05:38:33 -0400 Subject: [PATCH 2/2] format --- src/intrusive_ptr_hash.h | 3 --- src/ir_pattern_matching.h | 41 +++++++++++++++++++++++++-------------- src/ir_rewriter.cpp | 9 +++++++++ 3 files changed, 35 insertions(+), 18 deletions(-) diff --git a/src/intrusive_ptr_hash.h b/src/intrusive_ptr_hash.h index ba5dabff..d3ba576b 100644 --- a/src/intrusive_ptr_hash.h +++ b/src/intrusive_ptr_hash.h @@ -22,8 +22,6 @@ struct hash> { }\ }; -// SFINAE on detecting base class doesn't work with clang, need to wait for C++17. Workaround here - HASH_INTRUSIVE_PTR(simit::ir::Var) HASH_INTRUSIVE_PTR(simit::ir::IndexVar) HASH_INTRUSIVE_PTR(simit::ir::Expr) @@ -31,7 +29,6 @@ HASH_INTRUSIVE_PTR(simit::ir::Stmt) HASH_INTRUSIVE_PTR(simit::ir::Func) #undef HASH_INTRUSIVE_PTR - } #endif diff --git a/src/ir_pattern_matching.h b/src/ir_pattern_matching.h index f0623f5a..e27f1dc1 100644 --- a/src/ir_pattern_matching.h +++ b/src/ir_pattern_matching.h @@ -16,27 +16,31 @@ struct HandleCallbackArguments { } #if __cplusplus < 201700L - template ::type>::value>::type> + template ::type>::value>::type> static bool call(const T* node, U&& func) { return func(node); } - template ::type>::value>::type> + template ::type>::value>::type> static bool call(const T* node, U&& func, int* dummy = 0) { func(node); return true; } #else - template ::type>::value>::type> + template ::type>::value>::type> static bool call(const T* node, U&& func) { return func(node); } - template ::type>::value>::type> + template ::type>::value>::type> static bool call(const T* node, U&& func, int* dummy = 0) { func(node); return true; @@ -99,7 +103,8 @@ class PatternMatch { }\ \ template \ - static bool match(ExprOrStmt expr, const std::tuple& callbacks) {\ + static bool match(ExprOrStmt expr,\ + const std::tuple& callbacks) {\ if (isa(expr)) {\ const TYPE* node = to(expr);\ return HandleCallbackArguments::call(node, std::get<0>(callbacks)) &&\ @@ -123,8 +128,10 @@ class PatternMatch { return isa(expr);\ }\ \ - template \ - static bool match(ExprOrStmt expr, const std::tuple& callbacks) {\ + template \ + static bool match(ExprOrStmt expr,\ + const std::tuple& callbacks) {\ if (isa(expr)) {\ const TYPE* node = to(expr);\ return HandleCallbackArguments::call(node, std::get<0>(callbacks)) &&\ @@ -149,8 +156,10 @@ class PatternMatch { return isa(expr);\ }\ \ - template \ - static bool match(ExprOrStmt expr, const std::tuple& callbacks) {\ + template \ + static bool match(ExprOrStmt expr,\ + const std::tuple& callbacks) {\ if (isa(expr)) {\ const TYPE* node = to(expr);\ return HandleCallbackArguments::call(node, std::get<0>(callbacks)) &&\ @@ -183,7 +192,8 @@ class PatternMatch { } template - static bool match(ExprOrStmt expr, const std::tuple& callbacks) { + static bool match(ExprOrStmt expr, + const std::tuple& callbacks) { if (isa(expr)) { const UnaryExpr* node = to(expr); return HandleCallbackArguments::call(node, std::get<0>(callbacks)) && @@ -207,7 +217,8 @@ class PatternMatch { } template - static bool match(ExprOrStmt expr, const std::tuple& callbacks) { + static bool match(ExprOrStmt expr, + const std::tuple& callbacks) { if (isa(expr)) { const BinaryExpr* node = to(expr); return HandleCallbackArguments::call(node, std::get<0>(callbacks)) && diff --git a/src/ir_rewriter.cpp b/src/ir_rewriter.cpp index 9b491e28..7066f893 100644 --- a/src/ir_rewriter.cpp +++ b/src/ir_rewriter.cpp @@ -10,6 +10,9 @@ Expr IRRewriter::rewrite(Expr e) { e.accept(this); e = expr; } + else { + e = Expr(); + } expr = Expr(); stmt = Stmt(); func = Func(); @@ -25,6 +28,9 @@ Stmt IRRewriter::rewrite(Stmt s) { } s = stmt; } + else { + s = Stmt(); + } expr = Expr(); stmt = Stmt(); func = Func(); @@ -36,6 +42,9 @@ Func IRRewriter::rewrite(Func f) { f.accept(this); f = func; } + else { + f = Func(); + } expr = Expr(); stmt = Stmt(); func = Func();