Skip to content

Commit a68823d

Browse files
authored
[Fix] Remove CastRemover for Tensor Indices (issue cornell-zhang#386) (cornell-zhang#453)
* [Test] Add test case for issue cornell-zhang#386 * [Fix] Remove CastRemover for tensor indices and make sure index is i32 type in LLVM backend (Fix cornell-zhang#386) * [Fix] Fix issue with simplying index expression in generate_reuse_buffer * [Fix] Add type checking in CastRemover * [Fix] Move CastRemover to ir_util.h
1 parent abb2f0e commit a68823d

File tree

5 files changed

+160
-7
lines changed

5 files changed

+160
-7
lines changed

python/heterocl/tensor.py

-5
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,6 @@ def __setitem__(self, indices, expr):
148148
if not isinstance(indices, tuple):
149149
indices = (indices,)
150150
indices = self.indices + indices
151-
indices = util.CastRemover().mutate(indices)
152151
index, bit, _ = util.get_index(self.tensor.shape, indices, 0)
153152
if not Stage.get_len():
154153
raise TensorError("Cannot set tensor elements without compute APIs")
@@ -238,7 +237,6 @@ def shape(self):
238237
def asnode(self):
239238
if len(self.indices) < len(self.tensor.shape):
240239
raise TensorError("Accessing a slice of tensor is not allowed")
241-
self.indices = util.CastRemover().mutate(self.indices)
242240
index, bit, _ = util.get_index(self.tensor.shape, self.indices, 0)
243241
if bit is None:
244242
return _make.Load(self._dtype, self.tensor.buf.data, index)
@@ -343,20 +341,17 @@ def __repr__(self):
343341
return "Tensor('" + self.name + "', " + str(self.shape) + ", " + str(self.dtype) + ")"
344342

345343
def __getitem__(self, indices):
346-
indices = util.CastRemover().mutate(indices)
347344
if Stage.get_len():
348345
Stage.get_current().input_stages.add(self.last_update)
349346
if not isinstance(indices, tuple):
350347
indices = (indices,)
351348
return TensorSlice(self, indices)
352349

353350
def __setitem__(self, indices, expr):
354-
indices = util.CastRemover().mutate(indices)
355351
Stage.get_current().input_stages.add(self.last_update)
356352
Stage.get_current().lhs_tensors.add(self)
357353
if not isinstance(indices, tuple):
358354
indices = (indices,)
359-
indices = util.CastRemover().mutate(indices)
360355
if len(indices) < len(self.shape):
361356
raise TensorError("Accessing a slice of tensor is not allowed")
362357
else:

tests/issues/test_issue_386.py

+23
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
import heterocl as hcl
2+
3+
def test_cast_removal():
4+
hcl.init()
5+
6+
A = hcl.placeholder((10,10), dtype=hcl.UInt(16), name="A")
7+
B = hcl.placeholder((10,10), dtype=hcl.Int(16), name="B")
8+
9+
def algo(A, B):
10+
def f_mutate(i,j):
11+
factor = hcl.scalar(B[0][0][13:11], name="factor")
12+
idx = hcl.scalar(B[0][0][11:0], dtype=hcl.UInt(16), name="idx")
13+
idx += i * hcl.cast(hcl.UInt(16), factor.v)
14+
A[idx][j] = B[idx][j]
15+
bound = hcl.scalar(5, dtype=hcl.Int(32))
16+
domain = (hcl.cast(hcl.UInt(32), bound.v), hcl.cast(hcl.UInt(32), bound.v))
17+
hcl.mutate(domain, f_mutate)
18+
19+
s = hcl.create_schedule([A, B], algo)
20+
f = hcl.build(s, target="vhls")
21+
22+
if __name__ == '__main__':
23+
test_cast_removal()

tvm/src/codegen/llvm/codegen_llvm.cc

+9
Original file line numberDiff line numberDiff line change
@@ -972,7 +972,11 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const Load* op) {
972972
Type t = op->type;
973973
bool is_volatile = volatile_buf_.count(op->buffer_var.get());
974974
llvm::Value* buffer = MakeValue(op->buffer_var);
975+
// Make sure index is int32 type since we removed CastRemover pass.
976+
// check op->index's datatype, if it is not int32, cast it to int32
975977
llvm::Value* index = MakeValue(op->index);
978+
if (op->index.type() != Int(32))
979+
index = CreateCast(op->index.type(), Int(32), index);
976980

977981
if (t.lanes() == 1) {
978982
int alignment, native_bits;
@@ -1257,6 +1261,11 @@ void CodeGenLLVM::VisitStmt_(const Store* op) {
12571261
llvm::Value* index = MakeValue(op->index);
12581262
llvm::Value* value = MakeValue(op->value);
12591263

1264+
// Make sure index is int32 type since we removed CastRemover pass.
1265+
// check op->index's datatype, if it is not int32, cast it to int32
1266+
if (op->index.type() != Int(32))
1267+
index = CreateCast(op->index.type(), Int(32), index);
1268+
12601269
if (t.lanes() == 1) {
12611270
int alignment, native_bits;
12621271
GetAlignment(t, op->buffer_var.get(), op->index, &alignment, &native_bits);

tvm/src/pass/generate_reuse_buffer.cc

+9-2
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
#include <tvm/ir_pass.h>
99
#include <tvm/ir_visitor.h>
1010
#include <tvm/operation.h>
11+
#include "ir_util.h"
1112

1213
namespace TVM {
1314
namespace ir {
@@ -222,7 +223,8 @@ class ReuseBufferInserter final : public IRMutator {
222223
// check if the bounde is constant
223224
// e.g. x+r => diff_expr = 10
224225
// e.g. y+c => diff_expr = 3
225-
Expr diff_expr = Simplify(max_expr - min_expr + 1);
226+
CastRemover castRemover;
227+
Expr diff_expr = Simplify(castRemover.Mutate(max_expr - min_expr) + 1);
226228
if (!is_const(diff_expr)) // e.g. y*(y+c) would be illegal
227229
LOG(FATAL) << "Irregular access pattern is not yet supported";
228230
// check if the specified axis is reused by running the next iteration
@@ -231,7 +233,7 @@ class ReuseBufferInserter final : public IRMutator {
231233
// first check if the axis is the specified reuse axis
232234
// e.g. y => y+1
233235
Expr next_min = substitute(next_subst, min_expr);
234-
Expr next_diff = Simplify(next_min - min_expr);
236+
Expr next_diff = Simplify(castRemover.Mutate(next_min - min_expr));
235237
if (!is_const(next_diff)) // e.g. y*y+c would be illegal
236238
LOG(FATAL) << "Irregular access pattern is not yet supported";
237239
// then check if we there is reuse in this axis
@@ -262,7 +264,10 @@ class ReuseBufferInserter final : public IRMutator {
262264
std::vector<VarExpr> reuse_loop_vars;
263265
for (size_t dim = 0; dim < ndim; dim++) {
264266
Expr index = min_list[dim];
267+
CastRemover castRemover;
268+
index = castRemover.Mutate(index);
265269
Expr reuse_index = Simplify(substitute(null_axis_subst_, index));
270+
266271
// create a new variable if the shape is not one
267272
if (!is_one(reuse_shape[dim])) {
268273
// TODO(Sean): fix the name
@@ -280,7 +285,9 @@ class ReuseBufferInserter final : public IRMutator {
280285
}
281286
}
282287
}
288+
283289
Expr rhs = substitute(reuse_index, new_loop_var, index);
290+
284291
// special case when the reuse index is 0
285292
if (is_zero(reuse_index) && dim == static_cast<size_t>(reuse))
286293
rhs = rhs + new_loop_var;

tvm/src/pass/ir_util.h

+119
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
#include <tvm/ir.h>
1010
#include <tvm/runtime/device_api.h>
11+
#include <tvm/ir_mutator.h>
1112
#include <vector>
1213

1314
namespace TVM {
@@ -154,6 +155,124 @@ inline int GetTempAllocaAlignment(Type type, int32_t const_size) {
154155
}
155156
return align;
156157
}
158+
159+
// Remove cast in binary expressions
160+
// Example: (cast(x) + cast(y)) -> (x + y)
161+
// Usage: CastRemover castRemover;
162+
// Expr expr = castRemover.Mutate(expr);
163+
class CastRemover final : public IRMutator {
164+
public:
165+
CastRemover() {}
166+
167+
Expr Mutate_(const Cast* op, const Expr& e) {
168+
return op->value;
169+
}
170+
171+
Expr Mutate_(const Add* op, const Expr& e) {
172+
Expr a = this->Mutate(op->a);
173+
Expr b = this->Mutate(op->b);
174+
if (const Cast* ca = a.as<Cast>()) {
175+
a = ca->value;
176+
}
177+
if (const Cast* cb = b.as<Cast>()) {
178+
b = cb->value;
179+
}
180+
if (a.type() != b.type())
181+
LOG(FATAL) << "CastRemover: type mismatch "
182+
<< a.type() << " vs " << b.type();
183+
return Add::make(a, b);
184+
}
185+
186+
Expr Mutate_(const Sub* op, const Expr& e) {
187+
Expr a = this->Mutate(op->a);
188+
Expr b = this->Mutate(op->b);
189+
if (const Cast* ca = a.as<Cast>()) {
190+
a = ca->value;
191+
}
192+
if (const Cast* cb = b.as<Cast>()) {
193+
b = cb->value;
194+
}
195+
if (a.type() != b.type())
196+
LOG(FATAL) << "CastRemover: type mismatch "
197+
<< a.type() << " vs " << b.type();
198+
return Sub::make(a, b);
199+
}
200+
201+
Expr Mutate_(const Mul* op, const Expr& e) {
202+
Expr a = this->Mutate(op->a);
203+
Expr b = this->Mutate(op->b);
204+
if (const Cast* ca = a.as<Cast>()) {
205+
a = ca->value;
206+
}
207+
if (const Cast* cb = b.as<Cast>()) {
208+
b = cb->value;
209+
}
210+
if (a.type() != b.type())
211+
LOG(FATAL) << "CastRemover: type mismatch "
212+
<< a.type() << " vs " << b.type();
213+
return Mul::make(a, b);
214+
}
215+
216+
Expr Mutate_(const Div* op, const Expr& e) {
217+
Expr a = this->Mutate(op->a);
218+
Expr b = this->Mutate(op->b);
219+
if (const Cast* ca = a.as<Cast>()) {
220+
a = ca->value;
221+
}
222+
if (const Cast* cb = b.as<Cast>()) {
223+
b = cb->value;
224+
}
225+
if (a.type() != b.type())
226+
LOG(FATAL) << "CastRemover: type mismatch "
227+
<< a.type() << " vs " << b.type();
228+
return Div::make(a, b);
229+
}
230+
231+
Expr Mutate_(const Mod* op, const Expr& e) {
232+
Expr a = this->Mutate(op->a);
233+
Expr b = this->Mutate(op->b);
234+
if (const Cast* ca = a.as<Cast>()) {
235+
a = ca->value;
236+
}
237+
if (const Cast* cb = b.as<Cast>()) {
238+
b = cb->value;
239+
}
240+
if (a.type() != b.type())
241+
LOG(FATAL) << "CastRemover: type mismatch "
242+
<< a.type() << " vs " << b.type();
243+
return Mod::make(a, b);
244+
}
245+
246+
Expr Mutate_(const Min* op, const Expr& e) {
247+
Expr a = this->Mutate(op->a);
248+
Expr b = this->Mutate(op->b);
249+
if (const Cast* ca = a.as<Cast>()) {
250+
a = ca->value;
251+
}
252+
if (const Cast* cb = b.as<Cast>()) {
253+
b = cb->value;
254+
}
255+
if (a.type() != b.type())
256+
LOG(FATAL) << "CastRemover: type mismatch "
257+
<< a.type() << " vs " << b.type();
258+
return Min::make(a, b);
259+
}
260+
261+
Expr Mutate_(const Max* op, const Expr& e) {
262+
Expr a = this->Mutate(op->a);
263+
Expr b = this->Mutate(op->b);
264+
if (const Cast* ca = a.as<Cast>()) {
265+
a = ca->value;
266+
}
267+
if (const Cast* cb = b.as<Cast>()) {
268+
b = cb->value;
269+
}
270+
if (a.type() != b.type())
271+
LOG(FATAL) << "CastRemover: type mismatch "
272+
<< a.type() << " vs " << b.type();
273+
return Max::make(a, b);
274+
}
275+
};
157276
} // namespace ir
158277
} // namespace TVM
159278
#endif // PASS_IR_UTIL_H_

0 commit comments

Comments
 (0)