[TensorExpr] Make Load and Store multi-dimensional. (pytorch#35800)
Summary:
Pull Request resolved: pytorch#35800

This PR includes the following changes:
* Introduce a new `Expr` type `Buf`: it plays a role similar to `Var`, but it also carries dimension info.
* Use the new `Buf` class in `Store` and `Load` instead of `Var` to specify where we store to or load from. `Buf` carries the dimension info of the buffer we're loading from or storing to, so we can keep N-d indices without flattening them into a 1-d index ([x, y] vs. [x + y*W]). A minimal usage sketch follows this list.
* Index flattening is now a separate pass, executed in `LoopNest::prepareForCodegen`: backends still expect flattened indices, and this PR preserves that.
* `Tensor` now contains a `Buf` instead of a `Var`, so the dimension info now lives on the `Tensor` (previously it was a property of a `Function`, not a `Tensor`). This brings us closer to a `Tensor` being a combination of Buffer + Function, where the Buffer specifies the iteration domain and the Function defines the computation.

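For illustration, here is a minimal sketch of how the pieces fit together after this change, modeled on the updated tests below. The function name, sizes, and include paths are illustrative assumptions, not part of this PR:

// Sketch only -- mirrors the patterns in the updated tests below.
// Include paths and names are assumptions, not taken from this PR.
#include <torch/csrc/jit/tensorexpr/buffer.h>
#include <torch/csrc/jit/tensorexpr/loopnest.h>
#include <torch/csrc/jit/tensorexpr/tensor.h>

using namespace torch::jit::tensorexpr;

void multiDimSketch() {
  KernelScope kernel_scope;
  constexpr int M = 32;
  constexpr int N = 48;

  // A Buf (seen through BufHandle) plays the role Var used to play in
  // Load/Store, but it also carries the dimensions {M, N}.
  Buffer a(BufHandle("a", {M, N}), kFloat);

  // Loads and stores now take a vector of indices ({i, j}) instead of a
  // single pre-flattened index (e.g. [x + y*W] as in the summary above).
  VarHandle i("i", kInt);
  VarHandle j("j", kInt);
  ExprHandle load_a = Load::make(a, {i, j}, /*mask=*/1);
  (void)load_a;

  // Compute produces the same multi-dimensional loads under the hood.
  Tensor* t = Compute(
      "t", {{M, "i"}, {N, "j"}}, [&](const VarHandle& x, const VarHandle& y) {
        return a(x, y) + 1.0f;
      });

  LoopNest l({t});
  // Backends still expect 1-d indices; this pass flattens the [x, y]-style
  // indices into a single linear index before handing off to codegen.
  l.prepareForCodegen();
  Stmt* s = l.root_stmt();
  (void)s; // ready for a codegen such as SimpleIREvaluator or CudaCodeGen
}

The test diffs below are largely this same mechanical change: `VarHandle`-based buffer construction becomes `BufHandle`-based, scalar indices become index lists, and `prepareForCodegen()` is inserted before codegen.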
TODOs:
* Consider merging `Buffer` with `Buf` or `BufHandle`. It seems that we don't need all of them.
* Harden the logic for how we create buffers in the fuser pass. Currently it seems that we sometimes don't set dimensions.
* Use `Buf` in `Allocate` and `Free`.
* Make it clearer that `Function` doesn't "own" dimension info and that dimensions are a property of a `Tensor`, not a `Function`.

Differential Revision: D20789005

Test Plan: Imported from OSS

Reviewed By: zheng-xq

Pulled By: ZolotukhinM

fbshipit-source-id: e04188d1d297f195f1c46669c614557d6bb6cde4
Mikhail Zolotukhin authored and facebook-github-bot committed Apr 2, 2020
1 parent 676fc92 commit 3ef5ff6
Showing 28 changed files with 901 additions and 667 deletions.
396 changes: 198 additions & 198 deletions test/cpp/tensorexpr/test_aten.cpp

Large diffs are not rendered by default.

10 changes: 7 additions & 3 deletions test/cpp/tensorexpr/test_cuda.cpp
@@ -43,6 +43,7 @@ void testCudaTestVectorAdd01_impl() {
   std::vector<For*> loops = l.getLoopStmtsFor(c);
   l.setGPUBlockIndex(loops[1], 0);
   l.setGPUThreadIndex(loops[2], 0);
+  l.prepareForCodegen();
   Stmt* stmt = l.root_stmt();
   CudaCodeGen cuda_cg(stmt, c, a_buf, b_buf);
   const int N = block_count * block_size * num_iter;
@@ -113,6 +114,7 @@ static void testCudaTestVectorAdd02_impl(int N, int block_size) {
   l.splitWithMask(loops[0], block_size, &n_outer, &n_inner);
   l.setGPUBlockIndex(n_outer, 0);
   l.setGPUThreadIndex(n_inner, 0);
+  l.prepareForCodegen();
   Stmt* stmt = l.root_stmt();
   CudaCodeGen cuda_cg(stmt, c, a_buf, b_buf);
   PaddedBuffer<float> a_v(N);
@@ -161,13 +163,14 @@ void testCudaDynamicShape2D() {
   auto testWithSize = [](int32_t M, int32_t N) {
     VarHandle m("m", kInt);
     VarHandle n("n", kInt);
-    Buffer a(VarHandle("a", kHandle), kFloat, {m, n});
-    Buffer b(VarHandle("b", kHandle), kFloat, {m, n});
+    Buffer a(BufHandle("a", {m, n}), kFloat);
+    Buffer b(BufHandle("b", {m, n}), kFloat);
     Tensor* c = Compute(
         "c", {{m, "m"}, {n, "n"}}, [&](const VarHandle& i, const VarHandle& j) {
           return a(i, j) + b(i, j);
         });
     LoopNest l({c});
+    l.prepareForCodegen();
     Stmt* s = l.root_stmt();
     CudaCodeGen cg(s, {a, b, c, m, n});

@@ -237,6 +240,7 @@ void testCudaTestRand01() {
   std::vector<For*> loops = l.getLoopStmtsFor(c);
   l.setGPUBlockIndex(loops[1], 0);
   l.setGPUThreadIndex(loops[2], 0);
+  l.prepareForCodegen();
   Stmt* stmt = l.root_stmt();
   CudaCodeGen cuda_cg(stmt, c);
   const int N = block_count * block_size * num_iter;
@@ -280,7 +284,7 @@ void testCudaDynamicShapeSplit() {
   KernelScope ks;
   constexpr int N = 4096;
   VarHandle n("n", kInt);
-  Buffer a(VarHandle("a", kHandle), kFloat, {n});
+  Buffer a(BufHandle("a", {n}), kFloat);
   Tensor* b =
       Compute("b", {{n, "n"}}, [&](const VarHandle& i) { return a(i) * 2.0f; });
   LoopNest l({b});
44 changes: 22 additions & 22 deletions test/cpp/tensorexpr/test_expr.cpp
@@ -69,9 +69,9 @@ void testExprLetStmtTest01() {
   Buffer a_buf("a", kFloat, {1});
   Buffer b_buf("b", kFloat, {1});

-  ExprHandle load_a = Load::make(a_buf, 0, 1);
+  ExprHandle load_a = Load::make(a_buf, {0}, 1);
   VarHandle var = VarHandle("v", kFloat);
-  Stmt* store_b = Store::make(b_buf, 0, var, 1);
+  Stmt* store_b = Store::make(b_buf, {0}, var, 1);
   Stmt* let_store = LetStmt::make(var, load_a, store_b);
   SimpleIREvaluator eval(let_store, a_buf, b_buf);
@@ -182,9 +182,9 @@ void testExprVectorAdd01() {
   const int kVectorCount = 128;
   const int kTotalSize = kVectorSize * kVectorCount;

-  Buffer a_buf(VarHandle("A", kHandle), kFloat, {ExprHandle(kTotalSize)});
-  Buffer b_buf(VarHandle("B", kHandle), kFloat, {ExprHandle(kTotalSize)});
-  Buffer c_buf(VarHandle("C", kHandle), kFloat, {ExprHandle(kTotalSize)});
+  Buffer a_buf(BufHandle("A", {ExprHandle(kTotalSize)}), kFloat);
+  Buffer b_buf(BufHandle("B", {ExprHandle(kTotalSize)}), kFloat);
+  Buffer c_buf(BufHandle("C", {ExprHandle(kTotalSize)}), kFloat);

   /*
   Build the following:
@@ -197,16 +197,16 @@ void testExprVectorAdd01() {
   VarHandle index = VarHandle("index", kInt);
   ExprHandle load_a = Load::make(
       a_buf,
-      Ramp::make(index * kVectorSize, 1, kVectorSize),
+      {Ramp::make(index * kVectorSize, 1, kVectorSize)},
       Broadcast::make(1, kVectorSize));
   ExprHandle load_b = Load::make(
       b_buf,
-      Ramp::make(index * kVectorSize, 1, kVectorSize),
+      {Ramp::make(index * kVectorSize, 1, kVectorSize)},
       Broadcast::make(1, kVectorSize));
   ExprHandle value = load_a + load_b;
   Stmt* store_c = Store::make(
       c_buf,
-      Ramp::make(index * kVectorSize, 1, kVectorSize),
+      {Ramp::make(index * kVectorSize, 1, kVectorSize)},
       value,
       Broadcast::make(1, kVectorSize));
   Stmt* stmt = For::make(index, 0, kVectorCount, store_c);
@@ -232,9 +232,9 @@ void testExprVectorAdd01() {
 void testExprCompareSelectEQ() {
   KernelScope kernel_scope;
   constexpr int N = 1024;
-  Buffer a(VarHandle("A", kHandle), kInt, {N});
-  Buffer b(VarHandle("B", kHandle), kInt, {N});
-  Buffer c(VarHandle("C", kHandle), kInt, {N});
+  Buffer a(BufHandle("A", {N}), kInt);
+  Buffer b(BufHandle("B", {N}), kInt);
+  Buffer c(BufHandle("C", {N}), kInt);
   std::vector<int> a_buffer(N, 1);
   std::vector<int> b_buffer(N, 1);
   std::vector<int> c_buffer(N, 0);
@@ -248,10 +248,10 @@ void testExprCompareSelectEQ() {
       N,
       Store::make(
           c,
-          i,
+          {i},
           CompareSelect::make(
-              Load::make(a, i, mask),
-              Load::make(b, i, mask),
+              Load::make(a, {i}, mask),
+              Load::make(b, {i}, mask),
               CompareSelectOperation::kEQ),
           mask));

@@ -403,11 +403,11 @@ void testExprDynamicShapeAdd() {
   KernelScope kernel_scope;
   auto testWithSize = [](int32_t size) {
     VarHandle n("n", kInt);
-    Buffer a(VarHandle("a", kHandle), kFloat, {n});
-    Buffer b(VarHandle("b", kHandle), kFloat, {n});
-    Buffer c(VarHandle("c", kHandle), kFloat, {n});
+    Buffer a(BufHandle("a", {n}), kFloat);
+    Buffer b(BufHandle("b", {n}), kFloat);
+    Buffer c(BufHandle("c", {n}), kFloat);
     VarHandle i("i", kInt);
-    Stmt* s = For::make(i, 0, n, Store::make(c, i, a(i) + b(i), 1));
+    Stmt* s = For::make(i, 0, n, Store::make(c, {i}, a(i) + b(i), 1));
     std::vector<float> aData(size, 1.0f);
     std::vector<float> bData(size, 2.0f);
     std::vector<float> cData(size, 0.0f);
@@ -426,9 +426,9 @@ void testCond01() {
   Buffer a_buf("a", kFloat, {N});
   VarHandle index = VarHandle("index", kInt);
   Stmt* assign_x2 =
-      Store::make(VarHandle(a_buf.data()), index, cast<float>(index) * 2, 1);
+      Store::make(BufHandle(a_buf.data()), {index}, cast<float>(index) * 2, 1);
   Stmt* assign_x3 =
-      Store::make(VarHandle(a_buf.data()), index, cast<float>(index) * 3, 1);
+      Store::make(BufHandle(a_buf.data()), {index}, cast<float>(index) * 3, 1);
   ExprHandle even_cond = CompareSelect::make(Mod::make(index, 2), 0, kEQ);
   Stmt* assign = Cond::make(even_cond, assign_x2, assign_x3);
   Stmt* for_stmt = For::make(index, 0, N, assign);
@@ -476,7 +476,7 @@ void testStmtClone() {
   Buffer a_buf("a", kInt, {N});
   VarHandle index = VarHandle("index", kInt);
   Stmt* body =
-      Store::make(VarHandle(a_buf.data()), index, 5, 1);
+      Store::make(BufHandle(a_buf.data()), {index}, 5, 1);
   Stmt* loop = For::make(index, 0, N, body);

   Stmt* cloned_loop = Stmt::clone(loop);
@@ -490,7 +490,7 @@ void testStmtClone() {

   // Let's add another assign to the body in the cloned loop and verify that the
   // original statement hasn't changed while the cloned one has.
-  Stmt* body_addition = Store::make(VarHandle(a_buf.data()), index, 33, 1);
+  Stmt* body_addition = Store::make(BufHandle(a_buf.data()), {index}, 33, 1);
   Block* cloned_body =
       static_cast<Block*>(static_cast<const For*>(cloned_loop)->body());
   cloned_body->append_stmt(body_addition);
