Skip to content

Commit

Permalink
[NNC] Add loop unroll transformation (pytorch#42465)
Browse files Browse the repository at this point in the history
Summary:
Unroll a loop with constant boundaries, replacing it with multiple
instances of the loop body. For example:

```
for x in 0..3:
  A[x] = x*2
```

becomes:

```
A[0] = 0
A[1] = 2
A[2] = 4
```

Pull Request resolved: pytorch#42465

Test Plan: `test_tensorexpr` unit tests.

Reviewed By: agolynski

Differential Revision: D22914418

Pulled By: asuhan

fbshipit-source-id: 72ca10d7c0b1ac7f9a3688ac872bd94a1c53dc51
  • Loading branch information
asuhan authored and facebook-github-bot committed Aug 6, 2020
1 parent 3d46e02 commit 1848b43
Show file tree
Hide file tree
Showing 4 changed files with 209 additions and 0 deletions.
169 changes: 169 additions & 0 deletions test/cpp/tensorexpr/test_loopnest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1528,5 +1528,174 @@ void testOuterLoopVectorization() {
ASSERT_EQ(dynamic_cast<For*>(for_body->front()), nullptr);
}

namespace {

std::string constantUpperBoundLoopIR(int upper_bound_val) {
KernelScope kernel_scope;
ExprHandle upper_bound(upper_bound_val);
Tensor* A = Compute(
"A", {{upper_bound, "x"}}, [&](const VarHandle& x) { return x * 2; });
LoopNest l({A});
std::vector<For*> loops = l.getLoopStmtsFor(A);
Stmt* unrolled = nullptr;
LoopNest::unroll(loops[0], &unrolled);
std::ostringstream oss;
oss << *unrolled;
return oss.str();
}

} // namespace

// A loop over [0, 3) computing A[x] = x * 2 should unroll into three stores
// with the index and value folded to constants.
void testUnroll() {
  const std::string expected_ir =
      R"IR(
# CHECK: A[0] = 0;
# CHECK: A[1] = 2;
# CHECK: A[2] = 4)IR";

  const std::string unrolled_ir = constantUpperBoundLoopIR(3);
  torch::jit::testing::FileCheck().run(expected_ir, unrolled_ir);
}

void testUnrollOuter() {
KernelScope kernel_scope;
ExprHandle outer_bound(3);
ExprHandle inner_bound(4);
Tensor* A = Compute(
"A",
{{outer_bound, "x"}, {inner_bound, "y"}},
[&](const VarHandle& x, const VarHandle& y) { return x + y; });
LoopNest l({A});
std::vector<For*> loops = l.getLoopStmtsFor(A);
Stmt* unrolled = nullptr;
LoopNest::unroll(loops[0], &unrolled);
const std::string& verification_pattern =
R"IR(
# CHECK: for (int y = 0; y < 4; y++) {
# CHECK: A[0, y] = y;
# CHECK: }
# CHECK: for (int y = 0; y < 4; y++) {
# CHECK: A[1, y] = y + 1;
# CHECK: }
# CHECK: for (int y = 0; y < 4; y++) {
# CHECK: A[2, y] = y + 2;
# CHECK: })IR";

std::ostringstream oss;
oss << *unrolled;
torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
}

void testUnrollInner() {
KernelScope kernel_scope;
ExprHandle outer_bound(3);
ExprHandle inner_bound(4);
Tensor* A = Compute(
"A",
{{outer_bound, "x"}, {inner_bound, "y"}},
[&](const VarHandle& x, const VarHandle& y) { return x + y; });
LoopNest l({A});
std::vector<For*> loops = l.getLoopStmtsFor(A);
Stmt* unrolled = nullptr;
LoopNest::unroll(
static_cast<For*>(loops[0]->body()->stmts().front()), &unrolled);
const std::string& verification_pattern =
R"IR(
# CHECK: for (int x = 0; x < 3; x++) {
# CHECK: A[x, 0] = x;
# CHECK: A[x, 1] = x + 1;
# CHECK: A[x, 2] = x + 2;
# CHECK: A[x, 3] = x + 3;
# CHECK: })IR";

std::ostringstream oss;
oss << *loops[0];
torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
}

// A loop whose body holds two statements (A[x] = x * 2; B[x] = A[x]) should
// unroll into per-iteration copies of both statements, kept in body order
// within each iteration.
void testUnrollMultipleStatements() {
  KernelScope kernel_scope;
  const int kTotalSize = 3;
  BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kInt);
  BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kInt);

  VarHandle x("x", kInt);
  auto f = For::make(
      x,
      0,
      kTotalSize,
      Block::make({Store::make(a_buf, {x}, x * 2),
                   Store::make(b_buf, {x}, Load::make(a_buf, {x}, 1))}));
  // The result is intentionally discarded: wrapping `f` in a block presumably
  // gives it the parent that LoopNest::unroll requires.
  Block::make({f});
  Stmt* unrolled = nullptr;
  LoopNest::unroll(f, &unrolled);
  std::ostringstream oss;
  oss << *unrolled;
  // Every store is checked with its trailing ';' so that a pattern such as
  // "A[2] = 4" cannot accidentally match a longer value like "A[2] = 40".
  const std::string verification_pattern =
      R"IR(
# CHECK: A[0] = 0;
# CHECK: B[0] = A[0];
# CHECK: A[1] = 2;
# CHECK: B[1] = A[1];
# CHECK: A[2] = 4;
# CHECK: B[2] = A[2];)IR";

  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
}

// Unrolling a loop with an upper bound of 0 should emit no stores to A.
void testUnrollEmpty() {
  const std::string expected_ir = R"IR(
# CHECK-NOT: A[
)IR";

  const std::string unrolled_ir = constantUpperBoundLoopIR(0);
  torch::jit::testing::FileCheck().run(expected_ir, unrolled_ir);
}

void testNoUnroll() {
KernelScope kernel_scope;
VarHandle upper_bound("N", kInt);
Tensor* A = Compute(
"A", {{upper_bound, "x"}}, [&](const VarHandle& x) { return x * 2; });
LoopNest l({A});
std::vector<For*> loops = l.getLoopStmtsFor(A);
Stmt* unrolled = nullptr;
ASSERT_THROWS_WITH(
LoopNest::unroll(loops[0], &unrolled), "non-constant loop");
}

void testUnrollWithVarMap() {
KernelScope kernel_scope;
const int kTotalSize = 3;
BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kInt);
BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kInt);

VarHandle e("e", kInt);
VarHandle x("x", kInt);
auto f = For::make(
x,
0,
kTotalSize,
Block::make(
{{e.node(), new IntImm(7)}},
{Store::make(a_buf, {x}, e), Store::make(b_buf, {x}, e + 1)}));
Block::make({f});
Stmt* unrolled = nullptr;
LoopNest::unroll(f, &unrolled);
std::ostringstream oss;
oss << *unrolled;
const std::string& verification_pattern =
R"IR(
# CHECK: int e = 7;
# CHECK: A[0] = e;
# CHECK: B[0] = e + 1;
# CHECK: A[1] = e;
# CHECK: B[1] = e + 1;
# CHECK: A[2] = e;
# CHECK: B[2] = e + 1;)IR";

torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
}

} // namespace jit
} // namespace torch
7 changes: 7 additions & 0 deletions test/cpp/tensorexpr/tests.h
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,13 @@ namespace jit {
_(LoopNestReorderLongStringFull) \
_(LoopNestReorderInternalLoopNest) \
_(OuterLoopVectorization) \
_(Unroll) \
_(UnrollOuter) \
_(UnrollInner) \
_(UnrollMultipleStatements) \
_(UnrollEmpty) \
_(NoUnroll) \
_(UnrollWithVarMap) \
_(Kernel_1) \
_(Kernel_2) \
_(Kernel_3) \
Expand Down
32 changes: 32 additions & 0 deletions torch/csrc/jit/tensorexpr/loopnest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1165,6 +1165,38 @@ void LoopNest::reorderAxis(For* a, For* b) {
}
} // namespace tensorexpr

// Fully unrolls the loop `f`, which must have constant start and stop bounds.
//
// For every iteration value in [start, stop), each statement of the loop body
// is cloned and the loop variable is substituted with that value. The clones
// are gathered into a new Block (carrying over the body's variable bindings),
// simplified, and spliced into `f`'s parent block in place of `f`.
//
// Parameters:
//   f        - the loop to unroll; must be non-null and have a Block parent.
//   unrolled - out-parameter receiving the simplified replacement statement;
//              must be non-null.
//
// Throws:
//   malformed_input    - if `f` or `unrolled` is null, or `f` has no Block
//                        parent.
//   std::runtime_error - if the loop start or stop is not a constant.
void LoopNest::unroll(For* f, Stmt** unrolled) {
  // Validate `f` before anything touches it: the parent lookup below
  // dereferences the pointer.
  if (!f) {
    throw malformed_input("unroll attempted on null loop");
  }
  if (!unrolled) {
    throw malformed_input("unroll attempted with null output pointer");
  }
  Block* p = dynamic_cast<Block*>(f->get_parent());
  if (!p) {
    throw malformed_input("unroll attempted on loop with no parent");
  }

  if (!f->start()->isConstant()) {
    throw std::runtime_error("Can't unroll due to non-constant loop start!");
  }
  if (!f->stop()->isConstant()) {
    throw std::runtime_error("Can't unroll due to non-constant loop stop!");
  }

  std::vector<Stmt*> unrolled_stmts;
  const int start_val = immediateAs<int>(f->start());
  const int stop_val = immediateAs<int>(f->stop());
  for (int current = start_val; current < stop_val; ++current) {
    for (auto* stmt : f->body()->stmts()) {
      // Clone first so the original body is left untouched, then replace the
      // loop variable with this iteration's constant value.
      auto stmt_copy = Stmt::clone(stmt);
      unrolled_stmts.push_back(Substitute(
          stmt_copy,
          {{f->var(), getImmediateByType(f->var()->dtype(), current)}}));
    }
  }
  *unrolled = new Block(f->body()->varBindings(), unrolled_stmts);
  // Simplify so expressions over the substituted constants (e.g. 1 * 2) fold.
  *unrolled = IRSimplifier::simplify(*unrolled);

  p->replace_stmt(f, *unrolled);
}

std::vector<For*> LoopNest::getLoopStmtsFor(Tensor* t) const {
std::vector<For*> result;
Stmt* cur_stmt = tensor_to_stmt_.at(t);
Expand Down
1 change: 1 addition & 0 deletions torch/csrc/jit/tensorexpr/loopnest.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ class TORCH_API LoopNest {
void splitWithTail(For* f, int factor, For** outer, For** inner, For** tail);
void splitWithMask(For* f, int factor, For** outer, For** inner);
void reorderAxis(For* a, For* b);
static void unroll(For* f, Stmt** unrolled);

void setGPUBlockIndex(For* f, int idx);
void setGPUThreadIndex(For* f, int idx);
Expand Down

0 comments on commit 1848b43

Please sign in to comment.