From a57aaff5ea0e1609d447129dba81a794c20c6029 Mon Sep 17 00:00:00 2001 From: Chen Date: Mon, 4 Dec 2023 03:31:51 +0000 Subject: [PATCH] fix crash when eletwise inputs are different rank when two INPUT are different rank, AlignPermuteVectorForElementWise() will force align them and crash Type: Bug fix Signed-off-by: Chen --- .../ops/elementwise_layout_inference.h | 48 +++++++++++++++++++ 1 file changed, 48 insertions(+) diff --git a/src/tim/transform/ops/elementwise_layout_inference.h b/src/tim/transform/ops/elementwise_layout_inference.h index 102609d5d..30ac0557c 100644 --- a/src/tim/transform/ops/elementwise_layout_inference.h +++ b/src/tim/transform/ops/elementwise_layout_inference.h @@ -42,6 +42,30 @@ class ElementWiseLayoutInfer : public OpLayoutInfer { void OnInputs( std::vector>& next_tensors) override { + auto in_0 = op_->impl()->InputsTensor()[0]; + auto in_1 = op_->impl()->InputsTensor()[1]; + std::shared_ptr short_tensor = + in_0->GetShape().size() > in_1->GetShape().size() ? in_1 : in_0; + std::shared_ptr long_tensor = + in_0->GetShape().size() < in_1->GetShape().size() ? in_1 : in_0; + if (in_0->GetSpec().attr_ == tim::vx::INPUT && + in_1->GetSpec().attr_ == tim::vx::INPUT && + in_0->GetShape().size() != in_1->GetShape().size()) { + auto pv_long = context_->GetPermuteVector(long_tensor); + auto pv_short = context_->GetPermuteVector(short_tensor); + auto size_long = pv_long->Rank(); + auto size_short = pv_short->Rank(); + auto expand_pv = MakeShared(size_long); + // if different size, expand short pv as long pv + for (uint32_t i = 0; i < size_short; ++i) { + expand_pv->At(i) = pv_short->At(i); // replace low dims with short pv + } + auto expanded_shape = + GetExpandedShape(long_tensor->GetShape(), short_tensor->GetShape()); + short_tensor->GetSpec().SetShape(expanded_shape); + + context_->SetPermuteVector(short_tensor, expand_pv); // set new expand pv + } auto required_pv = AlignPermuteVectorForElementWise(); auto elementwise = context_->infer_graph_->CreateOperation(); for (const auto& i_src : op_->impl()->InputsTensor()) { @@ -63,6 +87,30 @@ class MultiplyLayoutInfer : public OpLayoutInfer { void OnInputs( std::vector>& next_tensors) override { + auto in_0 = op_->impl()->InputsTensor()[0]; + auto in_1 = op_->impl()->InputsTensor()[1]; + std::shared_ptr short_tensor = + in_0->GetShape().size() > in_1->GetShape().size() ? in_1 : in_0; + std::shared_ptr long_tensor = + in_0->GetShape().size() < in_1->GetShape().size() ? in_1 : in_0; + if (in_0->GetSpec().attr_ == tim::vx::INPUT && + in_1->GetSpec().attr_ == tim::vx::INPUT && + in_0->GetShape().size() != in_1->GetShape().size()) { + auto pv_long = context_->GetPermuteVector(long_tensor); + auto pv_short = context_->GetPermuteVector(short_tensor); + auto size_long = pv_long->Rank(); + auto size_short = pv_short->Rank(); + auto expand_pv = MakeShared(size_long); + // if different size, expand short pv to long pv + for (uint32_t i = 0; i < size_short; ++i) { + expand_pv->At(i) = pv_short->At(i); // replace low dims with short pv + } + auto expanded_shape = + GetExpandedShape(long_tensor->GetShape(), short_tensor->GetShape()); + short_tensor->GetSpec().SetShape(expanded_shape); + + context_->SetPermuteVector(short_tensor, expand_pv); // set new expand pv + } auto required_pv = AlignPermuteVectorForElementWise(); auto multiply = context_->infer_graph_->CreateOperation(