fix crash when elementwise inputs have different ranks
When the two inputs are both graph INPUT tensors of different rank, AlignPermuteVectorForElementWise()
force-aligns their permute vectors and crashes. The fix expands the lower-rank tensor's shape and
permute vector first (see the sketch after the diff below).

Type: Bug fix

Signed-off-by: Chen <[email protected]>
Chen committed Dec 4, 2023
1 parent 5173979 commit a57aaff
Showing 1 changed file with 48 additions and 0 deletions.
src/tim/transform/ops/elementwise_layout_inference.h — 48 additions, 0 deletions
@@ -42,6 +42,30 @@ class ElementWiseLayoutInfer : public OpLayoutInfer {

  void OnInputs(
      std::vector<std::shared_ptr<vx::Tensor>>& next_tensors) override {
    auto in_0 = op_->impl()->InputsTensor()[0];
    auto in_1 = op_->impl()->InputsTensor()[1];
    std::shared_ptr<tim::vx::Tensor> short_tensor =
        in_0->GetShape().size() > in_1->GetShape().size() ? in_1 : in_0;
    std::shared_ptr<tim::vx::Tensor> long_tensor =
        in_0->GetShape().size() < in_1->GetShape().size() ? in_1 : in_0;
    if (in_0->GetSpec().attr_ == tim::vx::INPUT &&
        in_1->GetSpec().attr_ == tim::vx::INPUT &&
        in_0->GetShape().size() != in_1->GetShape().size()) {
      auto pv_long = context_->GetPermuteVector(long_tensor);
      auto pv_short = context_->GetPermuteVector(short_tensor);
      auto size_long = pv_long->Rank();
      auto size_short = pv_short->Rank();
      auto expand_pv = MakeShared(size_long);
      // Ranks differ: expand the short pv to the long pv's rank.
      for (uint32_t i = 0; i < size_short; ++i) {
        expand_pv->At(i) = pv_short->At(i);  // replace the low dims with the short pv
      }
      auto expanded_shape =
          GetExpandedShape(long_tensor->GetShape(), short_tensor->GetShape());
      short_tensor->GetSpec().SetShape(expanded_shape);

      context_->SetPermuteVector(short_tensor, expand_pv);  // register the expanded pv
    }
    auto required_pv = AlignPermuteVectorForElementWise();
    auto elementwise = context_->infer_graph_->CreateOperation<OpType>();
    for (const auto& i_src : op_->impl()->InputsTensor()) {
@@ -63,6 +87,30 @@ class MultiplyLayoutInfer : public OpLayoutInfer {

  void OnInputs(
      std::vector<std::shared_ptr<vx::Tensor>>& next_tensors) override {
    auto in_0 = op_->impl()->InputsTensor()[0];
    auto in_1 = op_->impl()->InputsTensor()[1];
    std::shared_ptr<tim::vx::Tensor> short_tensor =
        in_0->GetShape().size() > in_1->GetShape().size() ? in_1 : in_0;
    std::shared_ptr<tim::vx::Tensor> long_tensor =
        in_0->GetShape().size() < in_1->GetShape().size() ? in_1 : in_0;
    if (in_0->GetSpec().attr_ == tim::vx::INPUT &&
        in_1->GetSpec().attr_ == tim::vx::INPUT &&
        in_0->GetShape().size() != in_1->GetShape().size()) {
      auto pv_long = context_->GetPermuteVector(long_tensor);
      auto pv_short = context_->GetPermuteVector(short_tensor);
      auto size_long = pv_long->Rank();
      auto size_short = pv_short->Rank();
      auto expand_pv = MakeShared(size_long);
      // Ranks differ: expand the short pv to the long pv's rank.
      for (uint32_t i = 0; i < size_short; ++i) {
        expand_pv->At(i) = pv_short->At(i);  // replace the low dims with the short pv
      }
      auto expanded_shape =
          GetExpandedShape(long_tensor->GetShape(), short_tensor->GetShape());
      short_tensor->GetSpec().SetShape(expanded_shape);

      context_->SetPermuteVector(short_tensor, expand_pv);  // register the expanded pv
    }
    auto required_pv = AlignPermuteVectorForElementWise();
    auto multiply =
        context_->infer_graph_->CreateOperation<tim::vx::ops::Multiply>(
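For illustration, here is a minimal standalone sketch of the expansion both hunks perform. The helper names ExpandShape and ExpandPermuteVector are assumptions for this sketch, not the TIM-VX API, and it assumes GetExpandedShape() pads the missing higher dims with 1: the lower-rank input's shape is padded up to the higher rank, and its permute vector is extended with identity entries for the new dims.

#include <cstdint>
#include <iostream>
#include <vector>

using Shape = std::vector<uint32_t>;

// Assumed behavior of GetExpandedShape(): pad the lower-rank shape with 1s
// in the missing higher dims so both shapes end up with the same rank.
Shape ExpandShape(const Shape& long_shape, const Shape& short_shape) {
  Shape expanded = short_shape;
  expanded.resize(long_shape.size(), 1);
  return expanded;
}

// Mirrors the pv expansion in the patch: start from an identity permute
// vector of the long rank, then overwrite the low dims with the short
// tensor's permute vector.
std::vector<uint32_t> ExpandPermuteVector(const std::vector<uint32_t>& pv_short,
                                          size_t long_rank) {
  std::vector<uint32_t> expanded(long_rank);
  for (size_t i = 0; i < long_rank; ++i) expanded[i] = static_cast<uint32_t>(i);
  for (size_t i = 0; i < pv_short.size(); ++i) expanded[i] = pv_short[i];
  return expanded;
}

int main() {
  // Two graph INPUT tensors of different rank (low-dim-first shapes).
  Shape long_shape{5, 3, 2};  // rank 3
  Shape short_shape{5};       // rank 1

  Shape expanded_shape = ExpandShape(long_shape, short_shape);     // {5, 1, 1}
  auto expanded_pv = ExpandPermuteVector({0}, long_shape.size());  // {0, 1, 2}

  for (auto d : expanded_shape) std::cout << d << ' ';
  std::cout << '\n';
  for (auto p : expanded_pv) std::cout << p << ' ';
  std::cout << '\n';
  return 0;
}

With both inputs brought to the same rank and given permute vectors of equal length, the subsequent AlignPermuteVectorForElementWise() call no longer has a rank mismatch to force-align.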
