Skip to content

Commit

Permalink
fix crash when eletwise inputs are different rank
Browse files Browse the repository at this point in the history
when two INPUT are different rank, AlignPermuteVectorForElementWise()
will force align them and crash

Type: Bug fix

Signed-off-by: Chen <[email protected]>
  • Loading branch information
Chen committed Dec 4, 2023
1 parent 5173979 commit 9cbcaa1
Showing 1 changed file with 27 additions and 0 deletions.
27 changes: 27 additions & 0 deletions src/tim/transform/ops/elementwise_layout_inference.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,33 @@ class ElementWiseLayoutInfer : public OpLayoutInfer {

void OnInputs(
std::vector<std::shared_ptr<vx::Tensor>>& next_tensors) override {
auto in_0 = op_->impl()->InputsTensor()[0];
auto in_1 = op_->impl()->InputsTensor()[1];
std::shared_ptr<tim::vx::Tensor> short_tensor =
in_0->GetShape().size() > in_1->GetShape().size() ? in_1 : in_0;
std::shared_ptr<tim::vx::Tensor> long_tensor =
in_0->GetShape().size() < in_1->GetShape().size() ? in_1 : in_0;
if (in_0->GetSpec().attr_ == tim::vx::INPUT &&
in_1->GetSpec().attr_ == tim::vx::INPUT &&
in_0->GetShape().size() != in_1->GetShape().size()) {
auto pv_long = context_->GetPermuteVector(long_tensor);
auto pv_short = context_->GetPermuteVector(short_tensor);
auto size_long = pv_long->Rank();
auto size_short = pv_short->Rank();
auto expand_pv = MakeShared(size_long);
// if different size, expand short pv as long pv
for (uint32_t i = 0; i < size_short; ++i) {
expand_pv->At(i) = pv_short->At(i); // replace low dims with short pv
}

auto short_shape = short_tensor->GetShape();
for(uint32_t i = 0;i<size_long;++i) { // expand shape and set new tensor shape
if(i >= size_short) short_shape.push_back(1);
}
short_tensor->GetSpec().SetShape(short_shape);

context_->SetPermuteVector(short_tensor, expand_pv); // set new expand pv
}
auto required_pv = AlignPermuteVectorForElementWise();
auto elementwise = context_->infer_graph_->CreateOperation<OpType>();
for (const auto& i_src : op_->impl()->InputsTensor()) {
Expand Down

0 comments on commit 9cbcaa1

Please sign in to comment.