Skip to content

Commit

Permalink
fix layoutinfer crash when logical op inputs are different rank
Browse files Browse the repository at this point in the history
Type: Bug fix

Signed-off-by: Chen <[email protected]>
  • Loading branch information
Chen committed Dec 12, 2023
1 parent 0dc7a34 commit 3f2291b
Showing 1 changed file with 27 additions and 0 deletions.
27 changes: 27 additions & 0 deletions src/tim/transform/ops/logical_layout_inference.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,33 @@ class LogicalOpsLayoutInfer : public OpLayoutInfer {

void OnInputs(
std::vector<std::shared_ptr<vx::Tensor>>& next_tensors) override {
auto in_0 = op_->impl()->InputsTensor()[0];
auto in_1 = op_->impl()->InputsTensor()[1];
std::shared_ptr<tim::vx::Tensor> short_tensor =
in_0->GetShape().size() > in_1->GetShape().size() ? in_1 : in_0;
std::shared_ptr<tim::vx::Tensor> long_tensor =
in_0->GetShape().size() < in_1->GetShape().size() ? in_1 : in_0;
if (in_0->GetSpec().attr_ != tim::vx::CONSTANT &&
in_1->GetSpec().attr_ != tim::vx::CONSTANT &&
in_0->GetShape().size() != in_1->GetShape().size()) {
auto pv_long = context_->GetPermuteVector(long_tensor);
auto pv_short = context_->GetPermuteVector(short_tensor);
auto rank_long = pv_long->Rank();
auto rank_short = pv_short->Rank();
auto expanded_pv = MakeShared(rank_long);
// if different size, expand short pv to long pv
for (uint32_t i = 0; i < rank_short; ++i) {
expanded_pv->At(i) = pv_short->At(i); // replace low dims with short pv
}
std::vector<uint32_t> expanded_shape(short_tensor->GetShape());
for (uint32_t i = 0; i < rank_long; ++i) {
if (i >= rank_short) expanded_shape.push_back(1);
}
short_tensor->GetSpec().SetShape(expanded_shape);

context_->SetPermuteVector(short_tensor,
expanded_pv); // set new expand pv
}
auto required_pv = AlignPermuteVectorForMutilInputs();
auto infer_out = CreateOutputsTensor(required_pv);
auto logical_op = context_->infer_graph_->CreateOperation<OpTpye>();
Expand Down

0 comments on commit 3f2291b

Please sign in to comment.