diff --git a/test/cpp/tensorexpr/test_loopnest.cpp b/test/cpp/tensorexpr/test_loopnest.cpp index 1af4391546d66..a3f9698683817 100644 --- a/test/cpp/tensorexpr/test_loopnest.cpp +++ b/test/cpp/tensorexpr/test_loopnest.cpp @@ -1528,5 +1528,174 @@ void testOuterLoopVectorization() { ASSERT_EQ(dynamic_cast(for_body->front()), nullptr); } +namespace { + +std::string constantUpperBoundLoopIR(int upper_bound_val) { + KernelScope kernel_scope; + ExprHandle upper_bound(upper_bound_val); + Tensor* A = Compute( + "A", {{upper_bound, "x"}}, [&](const VarHandle& x) { return x * 2; }); + LoopNest l({A}); + std::vector loops = l.getLoopStmtsFor(A); + Stmt* unrolled = nullptr; + LoopNest::unroll(loops[0], &unrolled); + std::ostringstream oss; + oss << *unrolled; + return oss.str(); +} + +} // namespace + +void testUnroll() { + const std::string actual = constantUpperBoundLoopIR(3); + const std::string& verification_pattern = + R"IR( +# CHECK: A[0] = 0; +# CHECK: A[1] = 2; +# CHECK: A[2] = 4)IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, actual); +} + +void testUnrollOuter() { + KernelScope kernel_scope; + ExprHandle outer_bound(3); + ExprHandle inner_bound(4); + Tensor* A = Compute( + "A", + {{outer_bound, "x"}, {inner_bound, "y"}}, + [&](const VarHandle& x, const VarHandle& y) { return x + y; }); + LoopNest l({A}); + std::vector loops = l.getLoopStmtsFor(A); + Stmt* unrolled = nullptr; + LoopNest::unroll(loops[0], &unrolled); + const std::string& verification_pattern = + R"IR( +# CHECK: for (int y = 0; y < 4; y++) { +# CHECK: A[0, y] = y; +# CHECK: } +# CHECK: for (int y = 0; y < 4; y++) { +# CHECK: A[1, y] = y + 1; +# CHECK: } +# CHECK: for (int y = 0; y < 4; y++) { +# CHECK: A[2, y] = y + 2; +# CHECK: })IR"; + + std::ostringstream oss; + oss << *unrolled; + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +void testUnrollInner() { + KernelScope kernel_scope; + ExprHandle outer_bound(3); + ExprHandle inner_bound(4); + Tensor* A = Compute( + "A", + {{outer_bound, "x"}, {inner_bound, "y"}}, + [&](const VarHandle& x, const VarHandle& y) { return x + y; }); + LoopNest l({A}); + std::vector loops = l.getLoopStmtsFor(A); + Stmt* unrolled = nullptr; + LoopNest::unroll( + static_cast(loops[0]->body()->stmts().front()), &unrolled); + const std::string& verification_pattern = + R"IR( +# CHECK: for (int x = 0; x < 3; x++) { +# CHECK: A[x, 0] = x; +# CHECK: A[x, 1] = x + 1; +# CHECK: A[x, 2] = x + 2; +# CHECK: A[x, 3] = x + 3; +# CHECK: })IR"; + + std::ostringstream oss; + oss << *loops[0]; + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +void testUnrollMultipleStatements() { + KernelScope kernel_scope; + const int kTotalSize = 3; + BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kInt); + BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kInt); + + VarHandle x("x", kInt); + auto f = For::make( + x, + 0, + kTotalSize, + Block::make({Store::make(a_buf, {x}, x * 2), + Store::make(b_buf, {x}, Load::make(a_buf, {x}, 1))})); + Block::make({f}); + Stmt* unrolled = nullptr; + LoopNest::unroll(f, &unrolled); + std::ostringstream oss; + oss << *unrolled; + const std::string& verification_pattern = + R"IR( +# CHECK: A[0] = 0; +# CHECK: B[0] = A[0]; +# CHECK: A[1] = 2; +# CHECK: B[1] = A[1]; +# CHECK: A[2] = 4 +# CHECK: B[2] = A[2];)IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +void testUnrollEmpty() { + const std::string actual = constantUpperBoundLoopIR(0); + const std::string& verification_pattern = R"IR( +# CHECK-NOT: A[ + )IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, actual); +} + +void testNoUnroll() { + KernelScope kernel_scope; + VarHandle upper_bound("N", kInt); + Tensor* A = Compute( + "A", {{upper_bound, "x"}}, [&](const VarHandle& x) { return x * 2; }); + LoopNest l({A}); + std::vector loops = l.getLoopStmtsFor(A); + Stmt* unrolled = nullptr; + ASSERT_THROWS_WITH( + LoopNest::unroll(loops[0], &unrolled), "non-constant loop"); +} + +void testUnrollWithVarMap() { + KernelScope kernel_scope; + const int kTotalSize = 3; + BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kInt); + BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kInt); + + VarHandle e("e", kInt); + VarHandle x("x", kInt); + auto f = For::make( + x, + 0, + kTotalSize, + Block::make( + {{e.node(), new IntImm(7)}}, + {Store::make(a_buf, {x}, e), Store::make(b_buf, {x}, e + 1)})); + Block::make({f}); + Stmt* unrolled = nullptr; + LoopNest::unroll(f, &unrolled); + std::ostringstream oss; + oss << *unrolled; + const std::string& verification_pattern = + R"IR( +# CHECK: int e = 7; +# CHECK: A[0] = e; +# CHECK: B[0] = e + 1; +# CHECK: A[1] = e; +# CHECK: B[1] = e + 1; +# CHECK: A[2] = e; +# CHECK: B[2] = e + 1;)IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + } // namespace jit } // namespace torch diff --git a/test/cpp/tensorexpr/tests.h b/test/cpp/tensorexpr/tests.h index bc7c76cd0ddb7..5d9f4a7e1703d 100644 --- a/test/cpp/tensorexpr/tests.h +++ b/test/cpp/tensorexpr/tests.h @@ -202,6 +202,13 @@ namespace jit { _(LoopNestReorderLongStringFull) \ _(LoopNestReorderInternalLoopNest) \ _(OuterLoopVectorization) \ + _(Unroll) \ + _(UnrollOuter) \ + _(UnrollInner) \ + _(UnrollMultipleStatements) \ + _(UnrollEmpty) \ + _(NoUnroll) \ + _(UnrollWithVarMap) \ _(Kernel_1) \ _(Kernel_2) \ _(Kernel_3) \ diff --git a/torch/csrc/jit/tensorexpr/loopnest.cpp b/torch/csrc/jit/tensorexpr/loopnest.cpp index 8c205c119ae68..fa5d045e52880 100644 --- a/torch/csrc/jit/tensorexpr/loopnest.cpp +++ b/torch/csrc/jit/tensorexpr/loopnest.cpp @@ -1165,6 +1165,38 @@ void LoopNest::reorderAxis(For* a, For* b) { } } // namespace tensorexpr +void LoopNest::unroll(For* f, Stmt** unrolled) { + Block* p = dynamic_cast(f->get_parent()); + if (!f) { + throw malformed_input("unroll attempted on null loop"); + } else if (!p) { + throw malformed_input("unroll attempted on loop with no parent"); + } + + if (!f->start()->isConstant()) { + throw std::runtime_error("Can't unroll due to non-constant loop start!"); + } + if (!f->stop()->isConstant()) { + throw std::runtime_error("Can't unroll due to non-constant loop stop!"); + } + + std::vector unrolled_stmts; + int start_val = immediateAs(f->start()); + int stop_val = immediateAs(f->stop()); + for (int current = start_val; current < stop_val; ++current) { + for (const auto stmt : f->body()->stmts()) { + auto stmt_copy = Stmt::clone(stmt); + unrolled_stmts.push_back(Substitute( + stmt_copy, + {{f->var(), getImmediateByType(f->var()->dtype(), current)}})); + } + } + *unrolled = new Block(f->body()->varBindings(), unrolled_stmts); + *unrolled = IRSimplifier::simplify(*unrolled); + + p->replace_stmt(f, *unrolled); +} + std::vector LoopNest::getLoopStmtsFor(Tensor* t) const { std::vector result; Stmt* cur_stmt = tensor_to_stmt_.at(t); diff --git a/torch/csrc/jit/tensorexpr/loopnest.h b/torch/csrc/jit/tensorexpr/loopnest.h index f9b024d39e632..aeb0e3744b5e5 100644 --- a/torch/csrc/jit/tensorexpr/loopnest.h +++ b/torch/csrc/jit/tensorexpr/loopnest.h @@ -39,6 +39,7 @@ class TORCH_API LoopNest { void splitWithTail(For* f, int factor, For** outer, For** inner, For** tail); void splitWithMask(For* f, int factor, For** outer, For** inner); void reorderAxis(For* a, For* b); + static void unroll(For* f, Stmt** unrolled); void setGPUBlockIndex(For* f, int idx); void setGPUThreadIndex(For* f, int idx);