fix crash when elementwise inputs have different ranks
When two INPUT tensors have different ranks, AlignPermuteVectorForElementWise()
force-aligns them and crashes.

Type: Bug fix

Signed-off-by: Chen <[email protected]>
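
To make the rank-alignment idea concrete, here is a minimal standalone sketch. It is not part of the commit and uses plain std:: containers instead of the TIM-VX PermuteVector/TensorSpec types; the helper names ExpandShape and ExpandPermuteVector are illustrative only. It pads the lower-rank input's shape with trailing 1s and extends its permute vector with identity entries, which is what the fix below does before calling AlignPermuteVectorForElementWise().

#include <cstdint>
#include <iostream>
#include <vector>

// Pad a shape with trailing 1s up to target_rank, e.g. {1, 2} -> {1, 2, 1, 1}.
std::vector<uint32_t> ExpandShape(std::vector<uint32_t> shape, size_t target_rank) {
  while (shape.size() < target_rank) shape.push_back(1);
  return shape;
}

// Keep the existing permutation for the low dims and map each new high dim
// to itself, e.g. {1, 0} with target_rank 4 -> {1, 0, 2, 3}.
std::vector<uint32_t> ExpandPermuteVector(std::vector<uint32_t> pv, size_t target_rank) {
  for (size_t i = pv.size(); i < target_rank; ++i) {
    pv.push_back(static_cast<uint32_t>(i));
  }
  return pv;
}

int main() {
  std::vector<uint32_t> short_shape = {1, 2};       // rank-2 input
  std::vector<uint32_t> long_shape = {1, 2, 3, 4};  // rank-4 input
  std::vector<uint32_t> short_pv = {0, 1};          // identity permute vector

  auto expanded_shape = ExpandShape(short_shape, long_shape.size());
  auto expanded_pv = ExpandPermuteVector(short_pv, long_shape.size());

  for (auto d : expanded_shape) std::cout << d << ' ';  // prints: 1 2 1 1
  std::cout << '\n';
  for (auto d : expanded_pv) std::cout << d << ' ';     // prints: 0 1 2 3
  std::cout << '\n';
}

With both inputs at the same rank, the existing element-wise alignment can proceed without forcing a mismatched permute.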
Chen committed Dec 6, 2023
1 parent 5173979 commit 2bc8d76
Showing 2 changed files with 114 additions and 0 deletions.
62 changes: 62 additions & 0 deletions src/tim/transform/layout_inference_test.cc
@@ -420,4 +420,66 @@ TEST(RoiAlign, nhwc) {
  std::vector<float> output(golden.size());
  EXPECT_TRUE(infer_output->CopyDataFromTensor(output.data()));
  EXPECT_TRUE(ArraysMatch(golden, output, 1e-5f));
}

// Regression test: elementwise ops whose inputs have different ranks must not
// crash layout inference.
TEST(Elementwise, DifferentRank) {
  auto ctx = tim::vx::Context::Create();
  auto src_graph = ctx->CreateGraph();

  tim::vx::ShapeType shape_1_2({1, 2});
  tim::vx::ShapeType shape_1_2_3({1, 2, 3});
  tim::vx::ShapeType shape_1_2_3_4({1, 2, 3, 4});

  tim::vx::TensorSpec input0_spec(tim::vx::DataType::FLOAT32, shape_1_2_3,
                                  tim::vx::TensorAttribute::INPUT);
  tim::vx::TensorSpec input1_spec(tim::vx::DataType::FLOAT32, shape_1_2,
                                  tim::vx::TensorAttribute::INPUT);
  tim::vx::TensorSpec input2_spec(tim::vx::DataType::FLOAT32, shape_1_2_3_4,
                                  tim::vx::TensorAttribute::INPUT);
  tim::vx::TensorSpec transient_spec(tim::vx::DataType::FLOAT32, {0, 0, 0},
                                     tim::vx::TensorAttribute::TRANSIENT);
  tim::vx::TensorSpec output_spec(tim::vx::DataType::FLOAT32, shape_1_2_3_4,
                                  tim::vx::TensorAttribute::OUTPUT);

  std::vector<float> input0_data = {1, 1, 1, 1, 1, 1};
  std::vector<float> input1_data = {1, 1};
  std::vector<float> input2_data = {
      1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
  };
  std::vector<float> golden = {
      1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
  };

  auto input_0 = src_graph->CreateTensor(input0_spec);
  auto input_1 = src_graph->CreateTensor(input1_spec);
  auto input_2 = src_graph->CreateTensor(input2_spec);
  auto transient_0 = src_graph->CreateTensor(transient_spec);
  auto transient_1 = src_graph->CreateTensor(transient_spec);
  auto output_t = src_graph->CreateTensor(output_spec);

  auto add = src_graph->CreateOperation<tim::vx::ops::Add>();
  auto multiply = src_graph->CreateOperation<tim::vx::ops::Multiply>();
  auto sub = src_graph->CreateOperation<tim::vx::ops::Sub>();
  (*add).BindInput(input_0).BindInput(input_1).BindOutput(transient_0);
  (*multiply).BindInput(input_1).BindInput(input_2).BindOutput(transient_1);
  (*sub).BindInput(transient_0).BindInput(transient_1).BindOutput(output_t);
  // Do layout inference
  auto transform = tim::transform::LayoutInference(src_graph, ctx);
  auto infer_graph = transform.first;
  auto graph_io_map = transform.second;
  infer_graph->Compile();

  auto infer_input0 = graph_io_map[src_graph->InputsTensor()[0]];
  auto infer_input1 = graph_io_map[src_graph->InputsTensor()[1]];
  auto infer_input2 = graph_io_map[src_graph->InputsTensor()[2]];
  auto infer_output = graph_io_map[src_graph->OutputsTensor()[0]];

  infer_input0->CopyDataToTensor(input0_data.data(), input0_data.size() * sizeof(float));
  infer_input1->CopyDataToTensor(input1_data.data(), input1_data.size() * sizeof(float));
  infer_input2->CopyDataToTensor(input2_data.data(), input2_data.size() * sizeof(float));
  infer_graph->Run();

  std::vector<float> output(golden.size());
  EXPECT_TRUE(infer_output->CopyDataFromTensor(output.data()));
  EXPECT_TRUE(ArraysMatch(golden, output, 1e-5f));
}
52 changes: 52 additions & 0 deletions src/tim/transform/ops/elementwise_layout_inference.h
@@ -42,6 +42,32 @@ class ElementWiseLayoutInfer : public OpLayoutInfer {

  void OnInputs(
      std::vector<std::shared_ptr<vx::Tensor>>& next_tensors) override {
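    // When the two non-constant inputs differ in rank, pad the lower-rank
    // tensor's shape and its permute vector with trailing dimensions so that
    // AlignPermuteVectorForElementWise() below sees inputs of equal rank.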
    auto in_0 = op_->impl()->InputsTensor()[0];
    auto in_1 = op_->impl()->InputsTensor()[1];
    std::shared_ptr<tim::vx::Tensor> short_tensor =
        in_0->GetShape().size() > in_1->GetShape().size() ? in_1 : in_0;
    std::shared_ptr<tim::vx::Tensor> long_tensor =
        in_0->GetShape().size() < in_1->GetShape().size() ? in_1 : in_0;
    if (in_0->GetSpec().attr_ != tim::vx::CONSTANT &&
        in_1->GetSpec().attr_ != tim::vx::CONSTANT &&
        in_0->GetShape().size() != in_1->GetShape().size()) {
      auto pv_long = context_->GetPermuteVector(long_tensor);
      auto pv_short = context_->GetPermuteVector(short_tensor);
      auto rank_long = pv_long->Rank();
      auto rank_short = pv_short->Rank();
      // expand the short pv to the long rank
      auto expanded_pv = MakeShared(rank_long);
      for (uint32_t i = 0; i < rank_short; ++i) {
        expanded_pv->At(i) = pv_short->At(i);  // replace low dims with short pv
      }
      // pad the short shape with trailing 1s up to the long rank
      std::vector<uint32_t> expanded_shape(short_tensor->GetShape());
      for (uint32_t i = rank_short; i < rank_long; ++i) {
        expanded_shape.push_back(1);
      }
      short_tensor->GetSpec().SetShape(expanded_shape);

      context_->SetPermuteVector(short_tensor, expanded_pv);  // set new expanded pv
    }
    auto required_pv = AlignPermuteVectorForElementWise();
    auto elementwise = context_->infer_graph_->CreateOperation<OpType>();
    for (const auto& i_src : op_->impl()->InputsTensor()) {
@@ -63,6 +89,32 @@ class MultiplyLayoutInfer : public OpLayoutInfer {

  void OnInputs(
      std::vector<std::shared_ptr<vx::Tensor>>& next_tensors) override {
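    // When the two non-constant inputs differ in rank, pad the lower-rank
    // tensor's shape and its permute vector with trailing dimensions so that
    // AlignPermuteVectorForElementWise() below sees inputs of equal rank.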
    auto in_0 = op_->impl()->InputsTensor()[0];
    auto in_1 = op_->impl()->InputsTensor()[1];
    std::shared_ptr<tim::vx::Tensor> short_tensor =
        in_0->GetShape().size() > in_1->GetShape().size() ? in_1 : in_0;
    std::shared_ptr<tim::vx::Tensor> long_tensor =
        in_0->GetShape().size() < in_1->GetShape().size() ? in_1 : in_0;
    if (in_0->GetSpec().attr_ != tim::vx::CONSTANT &&
        in_1->GetSpec().attr_ != tim::vx::CONSTANT &&
        in_0->GetShape().size() != in_1->GetShape().size()) {
      auto pv_long = context_->GetPermuteVector(long_tensor);
      auto pv_short = context_->GetPermuteVector(short_tensor);
      auto rank_long = pv_long->Rank();
      auto rank_short = pv_short->Rank();
      // expand the short pv to the long rank
      auto expanded_pv = MakeShared(rank_long);
      for (uint32_t i = 0; i < rank_short; ++i) {
        expanded_pv->At(i) = pv_short->At(i);  // replace low dims with short pv
      }
      // pad the short shape with trailing 1s up to the long rank
      std::vector<uint32_t> expanded_shape(short_tensor->GetShape());
      for (uint32_t i = rank_short; i < rank_long; ++i) {
        expanded_shape.push_back(1);
      }
      short_tensor->GetSpec().SetShape(expanded_shape);

      context_->SetPermuteVector(short_tensor, expanded_pv);  // set new expanded pv
    }
    auto required_pv = AlignPermuteVectorForElementWise();
    auto multiply =
        context_->infer_graph_->CreateOperation<tim::vx::ops::Multiply>(
