Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Copy Propagation #107

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ add_subdirectory(backend)
add_subdirectory(frontend)
add_subdirectory(lower)
add_subdirectory(visualizer)
add_subdirectory(program_analysis)

add_definitions(${SIMIT_DEFINITIONS})
include_directories(${SIMIT_INCLUDE_DIRS})
Expand Down
5 changes: 2 additions & 3 deletions src/backend/backend_visitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,8 @@ void BackendVisitorBase::compile(const ir::GPUKernel& kernel) {
#endif

void BackendVisitorBase::compile(const ir::Block& op) {
op.first.accept(this);
if (op.rest.defined()) {
op.rest.accept(this);
for (Stmt stmt : op.stmts) {
stmt.accept(this);
}
}

Expand Down
34 changes: 34 additions & 0 deletions src/intrusive_ptr_hash.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
#ifndef SIMIT_INTRUSIVE_PTR_HASH_H
#define SIMIT_INTRUSIVE_PTR_HASH_H

#include <functional>

#include "intrusive_ptr.h"
#include "ir.h"

namespace std {

template <typename T>
struct hash<simit::util::IntrusivePtr<T>> {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this actually used anywhere?

size_t operator()(const simit::util::IntrusivePtr<T>& x) const {
return (size_t) x.ptr;
}
};

#define HASH_INTRUSIVE_PTR(type) template <>\
struct hash<type> {\
size_t operator()(const type& x) const {\
return (size_t) x.ptr;\
}\
};

HASH_INTRUSIVE_PTR(simit::ir::Var)
HASH_INTRUSIVE_PTR(simit::ir::IndexVar)
HASH_INTRUSIVE_PTR(simit::ir::Expr)
HASH_INTRUSIVE_PTR(simit::ir::Stmt)
HASH_INTRUSIVE_PTR(simit::ir::Func)

#undef HASH_INTRUSIVE_PTR
}

#endif
23 changes: 17 additions & 6 deletions src/ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -837,17 +837,28 @@ Stmt Block::make(Stmt first, Stmt rest) {
}

Block *node = new Block;
node->first = first;
node->rest = rest;
if (isa<Block>(first)) {
const Block* block = to<Block>(first);
node->stmts = block->stmts;
} else {
node->stmts = {first};
}

if (rest.defined()) {
if (isa<Block>(rest)) {
const Block* block = to<Block>(rest);
node->stmts.insert(node->stmts.end(), block->stmts.begin(), block->stmts.end());
} else {
node->stmts.push_back(rest);
}
}
return node;
}

Stmt Block::make(std::vector<Stmt> stmts) {
iassert(stmts.size() > 0) << "Empty block";
Stmt node;
for (size_t i=stmts.size(); i>0; --i) {
node = Block::make(stmts[i-1], node);
}
Block *node = new Block;
node->stmts = stmts;
return node;
}

Expand Down
2 changes: 1 addition & 1 deletion src/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -439,7 +439,7 @@ struct Kernel : public StmtNode {
};

struct Block : public StmtNode {
Stmt first, rest;
vector<Stmt> stmts;
static Stmt make(Stmt first, Stmt rest);
static Stmt make(std::vector<Stmt> stmts);
void accept(IRVisitorStrict *v) const {v->visit((const Block*)this);}
Expand Down
282 changes: 282 additions & 0 deletions src/ir_pattern_matching.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,282 @@
#ifndef SIMIT_IR_PATTERN_MATCHING_H
#define SIMIT_IR_PATTERN_MATCHING_H

#include "ir.h"
#include <numeric>
#include <functional>
#include <type_traits>

namespace simit {
namespace ir {

struct HandleCallbackArguments {
template <typename T>
static bool call(const T* node, std::nullptr_t) {
return true;
}

#if __cplusplus < 201700L
template <typename T, typename U, typename =
typename std::enable_if<std::is_same<bool,
typename std::result_of<U(const T*)>::type>::value>::type>
static bool call(const T* node, U&& func) {
return func(node);
}

template <typename T, typename U, typename =
typename std::enable_if<std::is_same<void,
typename std::result_of<U(const T*)>::type>::value>::type>
static bool call(const T* node, U&& func, int* dummy = 0) {
func(node);
return true;
}
#else
template <typename T, typename U, typename =
typename std::enable_if<std::is_same<bool,
typename std::invoke_result<U(const T*)>::type>::value>::type>
static bool call(const T* node, U&& func) {
return func(node);
}

template <typename T, typename U, typename =
typename std::enable_if<std::is_same<void,
typename std::invoke_result<U(const T*)>::type>::value>::type>
static bool call(const T* node, U&& func, int* dummy = 0) {
func(node);
return true;
}
#endif
};

template <typename T>
class PatternMatch;

template <typename T>
class PatternMatch<T*> : public PatternMatch<T> {
};

template <>
class PatternMatch<void> {
public:
template <typename ExprOrStmt, typename Callback>
static bool match(ExprOrStmt expr, Callback&&) {
return true;
}
};

#define GENERATE_PATTERNMATCH_0(TYPE)\
template <>\
class PatternMatch<TYPE> {\
public:\
template <typename ExprOrStmt>\
static bool match(ExprOrStmt expr, std::nullptr_t) {\
return isa<TYPE>(expr);\
}\
\
template <typename ExprOrStmt>\
static bool match(ExprOrStmt expr, const std::tuple<std::nullptr_t>&) {\
return isa<TYPE>(expr);\
}\
\
template <typename ExprOrStmt, typename Callback>\
static bool match(ExprOrStmt expr, const std::tuple<Callback>& callbacks) {\
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Might be better if callbacks is just of type Callback rather than a tuple.

if (isa<TYPE>(expr)) {\
const TYPE* node = to<TYPE>(expr);\
return HandleCallbackArguments::call(node, std::get<0>(callbacks));\
}\
return false;\
}\
};

#define GENERATE_PATTERNMATCH_1(TYPE, MEMBER1)\
template <typename T1>\
class PatternMatch<TYPE(T1)> {\
public:\
template <typename ExprOrStmt>\
static bool match(ExprOrStmt expr, std::nullptr_t) {\
return isa<TYPE>(expr);\
}\
\
template <typename ExprOrStmt>\
static bool match(ExprOrStmt expr, const std::tuple<std::nullptr_t>&) {\
return isa<TYPE>(expr);\
}\
\
template <typename ExprOrStmt, typename Callback, typename Sub1>\
static bool match(ExprOrStmt expr,\
const std::tuple<Callback, Sub1>& callbacks) {\
if (isa<TYPE>(expr)) {\
const TYPE* node = to<TYPE>(expr);\
return HandleCallbackArguments::call(node, std::get<0>(callbacks)) &&\
PatternMatch<T1>::match(node->MEMBER1, std::get<1>(callbacks));\
}\
return false;\
}\
};

#define GENERATE_PATTERNMATCH_2(TYPE, MEMBER1, MEMBER2)\
template <typename T1, typename T2>\
class PatternMatch<TYPE(T1, T2)> {\
public:\
template <typename ExprOrStmt>\
static bool match(ExprOrStmt expr, std::nullptr_t) {\
return isa<TYPE>(expr);\
}\
\
template <typename ExprOrStmt>\
static bool match(ExprOrStmt expr, const std::tuple<std::nullptr_t>&) {\
return isa<TYPE>(expr);\
}\
\
template <typename ExprOrStmt, typename Callback,\
typename Sub1, typename Sub2>\
static bool match(ExprOrStmt expr,\
const std::tuple<Callback,Sub1, Sub2>& callbacks) {\
if (isa<TYPE>(expr)) {\
const TYPE* node = to<TYPE>(expr);\
return HandleCallbackArguments::call(node, std::get<0>(callbacks)) &&\
PatternMatch<T1>::match(node->MEMBER1, std::get<1>(callbacks)) &&\
PatternMatch<T2>::match(node->MEMBER2, std::get<2>(callbacks));\
}\
return false;\
}\
};

#define GENERATE_PATTERNMATCH_3(TYPE, MEMBER1, MEMBER2, MEMBER3)\
template <typename T1, typename T2, typename T3>\
class PatternMatch<TYPE(T1, T2, T3)> {\
public:\
template <typename ExprOrStmt>\
static bool match(ExprOrStmt expr, std::nullptr_t) {\
return isa<TYPE>(expr);\
}\
\
template <typename ExprOrStmt>\
static bool match(ExprOrStmt expr, const std::tuple<std::nullptr_t>&) {\
return isa<TYPE>(expr);\
}\
\
template <typename ExprOrStmt, typename Callback,\
typename Sub1, typename Sub2, typename Sub3>\
static bool match(ExprOrStmt expr,\
const std::tuple<Callback, Sub1, Sub2, Sub3>& callbacks) {\
if (isa<TYPE>(expr)) {\
const TYPE* node = to<TYPE>(expr);\
return HandleCallbackArguments::call(node, std::get<0>(callbacks)) &&\
PatternMatch<T1>::match(node->MEMBER1, std::get<1>(callbacks)) &&\
PatternMatch<T2>::match(node->MEMBER2, std::get<2>(callbacks)) &&\
PatternMatch<T3>::match(node->MEMBER3, std::get<3>(callbacks));\
}\
return false;\
}\
};

GENERATE_PATTERNMATCH_0(Literal)
GENERATE_PATTERNMATCH_0(VarExpr)
GENERATE_PATTERNMATCH_2(Load, buffer, index)
GENERATE_PATTERNMATCH_1(FieldRead, elementOrSet)
GENERATE_PATTERNMATCH_0(Length)
GENERATE_PATTERNMATCH_1(IndexRead, edgeSet)

template <typename T1>
class PatternMatch<UnaryExpr*(T1)> {
public:
template <typename ExprOrStmt>
static bool match(ExprOrStmt expr, std::nullptr_t) {
return isa<UnaryExpr>(expr);
}

template <typename ExprOrStmt>
static bool match(ExprOrStmt expr, const std::tuple<std::nullptr_t>&) {
return isa<UnaryExpr>(expr);
}

template <typename ExprOrStmt, typename Callback, typename Sub1>
static bool match(ExprOrStmt expr,
const std::tuple<Callback, Sub1>& callbacks) {
if (isa<UnaryExpr>(expr)) {
const UnaryExpr* node = to<UnaryExpr>(expr);
return HandleCallbackArguments::call(node, std::get<0>(callbacks)) &&
PatternMatch<T1>::match(node->a, std::get<1>(callbacks));
}
return false;
}
};

template <typename T1, typename T2>
class PatternMatch<BinaryExpr*(T1, T2)> {
public:
template <typename ExprOrStmt>
static bool match(ExprOrStmt expr, std::nullptr_t) {
return isa<BinaryExpr>(expr);
}

template <typename ExprOrStmt>
static bool match(ExprOrStmt expr, const std::tuple<std::nullptr_t>&) {
return isa<BinaryExpr>(expr);
}

template <typename ExprOrStmt, typename Callback, typename Sub1, typename Sub2>
static bool match(ExprOrStmt expr,
const std::tuple<Callback,Sub1, Sub2>& callbacks) {
if (isa<BinaryExpr>(expr)) {
const BinaryExpr* node = to<BinaryExpr>(expr);
return HandleCallbackArguments::call(node, std::get<0>(callbacks)) &&
PatternMatch<T1>::match(node->a, std::get<1>(callbacks)) &&
PatternMatch<T2>::match(node->b, std::get<2>(callbacks));
}
return false;
}
};

GENERATE_PATTERNMATCH_1(Neg, a)
GENERATE_PATTERNMATCH_2(Add, a, b)
GENERATE_PATTERNMATCH_2(Sub, a, b)
GENERATE_PATTERNMATCH_2(Mul, a, b)
GENERATE_PATTERNMATCH_2(Div, a, b)
GENERATE_PATTERNMATCH_2(Rem, a, b)

GENERATE_PATTERNMATCH_1(Not, a)
GENERATE_PATTERNMATCH_2(Eq, a, b)
GENERATE_PATTERNMATCH_2(Ne, a, b)
GENERATE_PATTERNMATCH_2(Gt, a, b)
GENERATE_PATTERNMATCH_2(Lt, a, b)
GENERATE_PATTERNMATCH_2(Ge, a, b)
GENERATE_PATTERNMATCH_2(Le, a, b)
GENERATE_PATTERNMATCH_2(And, a, b)
GENERATE_PATTERNMATCH_2(Or, a, b)
GENERATE_PATTERNMATCH_2(Xor, a, b)

GENERATE_PATTERNMATCH_0(VarDecl)
GENERATE_PATTERNMATCH_1(AssignStmt, value)
GENERATE_PATTERNMATCH_3(Store, buffer, index, value)
GENERATE_PATTERNMATCH_2(FieldWrite, elementOrSet, value)

// call

GENERATE_PATTERNMATCH_1(Scope, scopedStmt)
GENERATE_PATTERNMATCH_3(IfThenElse, condition, thenBody, elseBody)
GENERATE_PATTERNMATCH_3(ForRange, start, end, body)
GENERATE_PATTERNMATCH_1(For, body)
GENERATE_PATTERNMATCH_2(While, condition, body)
GENERATE_PATTERNMATCH_1(Kernel, body)

// block

GENERATE_PATTERNMATCH_1(Print, expr)
GENERATE_PATTERNMATCH_1(Comment,commentedStmt)
GENERATE_PATTERNMATCH_0(Pass)

GENERATE_PATTERNMATCH_2(UnnamedTupleRead, tuple, index)
GENERATE_PATTERNMATCH_1(NamedTupleRead, tuple)
// setread
// tensorread
//GENERATE_PATTERNMATCH_(TensorWrite)
GENERATE_PATTERNMATCH_1(IndexedTensor, tensor)
GENERATE_PATTERNMATCH_1(IndexExpr, value)
//GENERATE_PATTERNMATCH_(Map)

}
}

#endif
6 changes: 3 additions & 3 deletions src/ir_printer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -468,10 +468,10 @@ void IRPrinter::visit(const Kernel *op) {
}

void IRPrinter::visit(const Block *op) {
print(op->first);
if (op->rest.defined()) {
print(op->stmts.front());
for (size_t i = 1; i < op->stmts.size(); i++) {
os << endl;
print(op->rest);
print(op->stmts[i]);
}
}

Expand Down
Loading