Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refine layout inference #671

Merged
merged 20 commits into from
Dec 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 11 additions & 5 deletions src/tim/transform/layout_infer_context.h
Original file line number Diff line number Diff line change
@@ -1,16 +1,18 @@
#ifndef TIM_VX_LAYOUT_INFER_CONTEXT_H_
#define TIM_VX_LAYOUT_INFER_CONTEXT_H_

#include "permute_vector.h"
#include "tim/transform/layout_inference.h"

#include <unordered_map>

namespace tim {
namespace transform {
namespace layout_inference_impl {
class LayoutInferContext {
public:
LayoutInferContext(const std::shared_ptr<vx::Graph>& src_graph,
std::shared_ptr<vx::Graph>& infer_graph)
: src_graph_(src_graph), infer_graph_(infer_graph) {}
std::shared_ptr<vx::Graph>& infer_graph);
void SetPermuteVector(std::shared_ptr<vx::Tensor> tensor,
std::shared_ptr<IPermuteVector> pv);
const std::shared_ptr<IPermuteVector> GetPermuteVector(
Expand All @@ -20,14 +22,18 @@ class LayoutInferContext {
bool IsReadyForInfer(const std::shared_ptr<vx::Operation>& op) const;
void UpdateTensorMap(const std::shared_ptr<vx::Tensor>& t_src,
const std::shared_ptr<vx::Tensor>& t_layout);
std::shared_ptr<vx::Tensor> GetMapedTensor(
std::shared_ptr<vx::Tensor> GetMappedTensor(
const std::shared_ptr<vx::Tensor>& t_src) const;
std::shared_ptr<vx::Tensor> GetMappedGraphInputTensor(
const std::shared_ptr<vx::Tensor>& t_src) const;
std::shared_ptr<vx::Tensor> GetMappedGraphOutputTensor(
const std::shared_ptr<vx::Tensor>& t_src) const;

void UpdateGraphInputMap(const std::shared_ptr<vx::Tensor>& i_src,
const std::shared_ptr<vx::Tensor>& i_layout);

void UpdateGraphOutputMap(const std::shared_ptr<vx::Tensor>& o_src,
const std::shared_ptr<vx::Tensor>& o_layout);
const std::shared_ptr<vx::Tensor>& o_layout);

std::map<std::shared_ptr<vx::Tensor>, std::shared_ptr<vx::Tensor>>
GetGraphInputMap() const {
Expand All @@ -44,7 +50,7 @@ class LayoutInferContext {
private:
std::map<std::shared_ptr<vx::Tensor>, std::shared_ptr<IPermuteVector>>
tensor_pv_;
std::vector<std::shared_ptr<vx::Operation>> visited_op_;
std::unordered_map<std::shared_ptr<vx::Operation>, bool> op_visited_;
// tensor_in_src -> tensor_in_layout
std::map<std::shared_ptr<vx::Tensor>, std::shared_ptr<vx::Tensor>>
tensor_map_;
Expand Down
314 changes: 176 additions & 138 deletions src/tim/transform/layout_inference.cc

Large diffs are not rendered by default.

8 changes: 4 additions & 4 deletions src/tim/transform/ops/activation_layout_inference.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ class ActivationLayoutInfer : public OpLayoutInfer {
auto activation = op_->Clone(context_->infer_graph_);
auto out_infer = CreateOutputsTensor(input_pv);
(*activation)
.BindInput(context_->GetMapedTensor(i_src))
.BindInput(context_->GetMappedTensor(i_src))
.BindOutput(out_infer[0]);
context_->SetPermuteVector(op_->impl()->OutputsTensor()[0], input_pv);
next_tensors.push_back(op_->impl()->OutputsTensor()[0]);
Expand Down Expand Up @@ -107,7 +107,7 @@ class PReluLayoutInfer : public OpLayoutInfer {
context_->infer_graph_->CreateOperation<vx::ops::Reshape>(
boardcast_shape);
(*reshape)
.BindInput(context_->GetMapedTensor(src_slope))
.BindInput(context_->GetMappedTensor(src_slope))
.BindOutput(reshape_out);
context_->UpdateTensorMap(src_slope, reshape_out);
}
Expand All @@ -130,8 +130,8 @@ class PReluLayoutInfer : public OpLayoutInfer {
auto out_infer = CreateOutputsTensor(input_pv);

(*prelu)
.BindInput(context_->GetMapedTensor(src_input))
.BindInput(context_->GetMapedTensor(src_slope));
.BindInput(context_->GetMappedTensor(src_input))
.BindInput(context_->GetMappedTensor(src_slope));
(*prelu).BindOutput(out_infer[0]);
context_->SetPermuteVector(op_->impl()->OutputsTensor()[0], input_pv);
next_tensors.push_back(op_->impl()->OutputsTensor()[0]);
Expand Down
2 changes: 1 addition & 1 deletion src/tim/transform/ops/addn_layout_inference.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ class AddNLayoutInfer : public OpLayoutInfer {
auto addn = op_->Clone(context_->infer_graph_);

for (const auto& i_src : op_->impl()->InputsTensor()) {
(*addn).BindInput(context_->GetMapedTensor(i_src));
(*addn).BindInput(context_->GetMappedTensor(i_src));
}
auto infer_out = CreateOutputsTensor(required_pv);
(*addn).BindOutput(infer_out[0]);
Expand Down
2 changes: 1 addition & 1 deletion src/tim/transform/ops/arg_layout_inference.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ class ArgLayoutInfer : public OpLayoutInfer {

auto arg = op_->Clone(context_->infer_graph_);
auto infer_out = CreateOutputsTensor(input_pv);
(*arg).BindInput(context_->GetMapedTensor(src_input));
(*arg).BindInput(context_->GetMappedTensor(src_input));
(*arg).BindOutput(infer_out[0]);

context_->SetPermuteVector(op_->impl()->OutputsTensor()[0], input_pv);
Expand Down
4 changes: 2 additions & 2 deletions src/tim/transform/ops/batch2space_layout_inference.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ class Batch2SpaceLayoutInfer : public OpLayoutInfer {
auto final_pv = pv->Reverse()->Add(required_pv);
if (!final_pv->IsAligned()) {
auto perm_out =
InsertPermute(context_->GetMapedTensor(input_tensors[0]), final_pv);
InsertPermute(context_->GetMappedTensor(input_tensors[0]), final_pv);
context_->UpdateTensorMap(input_tensors[0], perm_out);
context_->SetPermuteVector(input_tensors[0], required_pv);
}
Expand All @@ -70,7 +70,7 @@ class Batch2SpaceLayoutInfer : public OpLayoutInfer {
context_->infer_graph_->CreateOperation<vx::ops::Batch2Space>(
block_size, crop, vx::DataLayout::WHCN);
auto out_tensor_infer = CreateOutputsTensor(required_pv);
(*batch2space).BindInput(context_->GetMapedTensor(input_tensors[0]));
(*batch2space).BindInput(context_->GetMappedTensor(input_tensors[0]));
(*batch2space).BindOutput(out_tensor_infer[0]);
context_->SetPermuteVector(op_->impl()->OutputsTensor()[0], required_pv);
// Add out tensor of src_graph into next_tensor
Expand Down
12 changes: 6 additions & 6 deletions src/tim/transform/ops/batchnorm_layout_inference.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ class BatchNormLayoutInfer : public OpLayoutInfer {
perm_out = context_->infer_graph_->CreateTensor(src_in->GetSpec(), (const void*)dataRef.data());
input_pv = MakeShared(src_in->GetShape().size());
} else {
perm_out = context_->GetMapedTensor(src_in);
perm_out = context_->GetMappedTensor(src_in);
input_pv = context_->GetPermuteVector(src_in);
context_->SetPermuteVector(src_in, input_pv);
if (idx == 0) {
Expand All @@ -73,11 +73,11 @@ class BatchNormLayoutInfer : public OpLayoutInfer {

auto batchnorm = op_->Clone(context_->infer_graph_);
auto out_tensor_infer = CreateOutputsTensor(required_pv);
(*batchnorm).BindInput(context_->GetMapedTensor(input_tensors[0]));
(*batchnorm).BindInput(context_->GetMapedTensor(input_tensors[1]));
(*batchnorm).BindInput(context_->GetMapedTensor(input_tensors[2]));
(*batchnorm).BindInput(context_->GetMapedTensor(input_tensors[3]));
(*batchnorm).BindInput(context_->GetMapedTensor(input_tensors[4]));
(*batchnorm).BindInput(context_->GetMappedTensor(input_tensors[0]));
(*batchnorm).BindInput(context_->GetMappedTensor(input_tensors[1]));
(*batchnorm).BindInput(context_->GetMappedTensor(input_tensors[2]));
(*batchnorm).BindInput(context_->GetMappedTensor(input_tensors[3]));
(*batchnorm).BindInput(context_->GetMappedTensor(input_tensors[4]));

(*batchnorm).BindOutput(out_tensor_infer[0]);
context_->SetPermuteVector(op_->impl()->OutputsTensor()[0], required_pv);
Expand Down
2 changes: 1 addition & 1 deletion src/tim/transform/ops/bidirectional_rnn_layout_inference.h
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ class BidirectionalRnnLayoutInfer : public OpLayoutInfer {


for (const auto& i_src : op_->impl()->InputsTensor()) {
(*cloned_op).BindInput(context_->GetMapedTensor(i_src));
(*cloned_op).BindInput(context_->GetMappedTensor(i_src));
}


Expand Down
2 changes: 1 addition & 1 deletion src/tim/transform/ops/broadcast_layout_inference.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ class BroadcastLayoutInfer : public OpLayoutInfer {
auto cloned_op = op_->Clone(context_->infer_graph_);

for (const auto& i_src : op_->impl()->InputsTensor()) {
(*cloned_op).BindInput(context_->GetMapedTensor(i_src));
(*cloned_op).BindInput(context_->GetMappedTensor(i_src));
}

std::vector<std::shared_ptr<IPermuteVector>> required_pv_lst;
Expand Down
2 changes: 1 addition & 1 deletion src/tim/transform/ops/concat_layout_inferene.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ class ConcatLayoutInfer : public OpLayoutInfer {
auto concat = context_->infer_graph_->CreateOperation<vx::ops::Concat>(
axis, op_->impl()->InputsTensor().size());
for (const auto& i_src : op_->impl()->InputsTensor()) {
(*concat).BindInput(context_->GetMapedTensor(i_src));
(*concat).BindInput(context_->GetMappedTensor(i_src));
}
auto out_infer = CreateOutputsTensor(required_pv);
(*concat).BindOutput(out_infer[0]);
Expand Down
12 changes: 6 additions & 6 deletions src/tim/transform/ops/conv2d_layout_inference.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,10 +79,10 @@ class Conv2dLayoutInfer : public OpLayoutInfer {
auto final_pv = input_pv->Reverse()->Add(required_pv);
if (!final_pv->IsAligned()) {
infer_input =
InsertPermute(context_->GetMapedTensor(input_tensors[0]), final_pv);
InsertPermute(context_->GetMappedTensor(input_tensors[0]), final_pv);
context_->SetPermuteVector(input_tensors[0], required_pv);
} else {
infer_input = context_->GetMapedTensor(input_tensors[0]);
infer_input = context_->GetMappedTensor(input_tensors[0]);
context_->SetPermuteVector(input_tensors[0], input_pv);
}
context_->UpdateTensorMap(input_tensors[0], infer_input);
Expand All @@ -104,10 +104,10 @@ class Conv2dLayoutInfer : public OpLayoutInfer {
auto final_pv = weight_pv->Reverse()->Add(weight_required_pv);
if (!final_pv->IsAligned()) {
infer_weight =
InsertPermute(context_->GetMapedTensor(input_tensors[1]), final_pv);
InsertPermute(context_->GetMappedTensor(input_tensors[1]), final_pv);
context_->SetPermuteVector(input_tensors[1], weight_required_pv);
} else {
infer_weight = context_->GetMapedTensor(input_tensors[1]);
infer_weight = context_->GetMappedTensor(input_tensors[1]);
context_->SetPermuteVector(input_tensors[1], weight_pv);
}
context_->UpdateTensorMap(input_tensors[1], infer_weight);
Expand All @@ -121,7 +121,7 @@ class Conv2dLayoutInfer : public OpLayoutInfer {
infer_bias = context_->infer_graph_->CreateTensor(
input_tensors[2]->GetSpec(), (const void*)dataRef.data());
} else {
infer_bias = context_->GetMapedTensor(input_tensors[2]);
infer_bias = context_->GetMappedTensor(input_tensors[2]);
}
auto bias_pv = MakeShared(1);
context_->UpdateTensorMap(input_tensors[2], infer_bias);
Expand All @@ -131,7 +131,7 @@ class Conv2dLayoutInfer : public OpLayoutInfer {
auto conv2d = op_->Clone(context_->infer_graph_);
auto otensor_infer = CreateOutputsTensor(required_pv);
for (const auto& i_src : input_tensors) {
(*conv2d).BindInput(context_->GetMapedTensor(i_src));
(*conv2d).BindInput(context_->GetMappedTensor(i_src));
}
(*conv2d).BindOutput(otensor_infer[0]);
context_->SetPermuteVector(op_->impl()->OutputsTensor()[0], required_pv);
Expand Down
8 changes: 4 additions & 4 deletions src/tim/transform/ops/conv3d_layout_inference.h
Original file line number Diff line number Diff line change
Expand Up @@ -81,18 +81,18 @@ class Conv3dLayoutInfer : public OpLayoutInfer {
} else {
// For bias
if (in->GetShape().size() == 1) {
infer_tensor = context_->GetMapedTensor(in);
infer_tensor = context_->GetMappedTensor(in);
trans_pv = MakeShared(1);
} else {
// For input/weight
auto pv = context_->GetPermuteVector(in);
auto final_pv = pv->Reverse()->Add(required_pv);
if (!final_pv->IsAligned()) {
infer_tensor =
InsertPermute(context_->GetMapedTensor(in), final_pv);
InsertPermute(context_->GetMappedTensor(in), final_pv);
trans_pv = required_pv;
} else {
infer_tensor = context_->GetMapedTensor(in);
infer_tensor = context_->GetMappedTensor(in);
trans_pv = pv;
}
}
Expand Down Expand Up @@ -131,7 +131,7 @@ class Conv3dLayoutInfer : public OpLayoutInfer {
vx::DataLayout::WHDCN, vx::DataLayout::WHDIcOc);
auto otensor_infer = CreateOutputsTensor(required_pv);
for (const auto& i_src : input_tensors) {
(*conv3d).BindInput(context_->GetMapedTensor(i_src));
(*conv3d).BindInput(context_->GetMappedTensor(i_src));
}
(*conv3d).BindOutput(otensor_infer[0]);
context_->SetPermuteVector(op_->impl()->OutputsTensor()[0], required_pv);
Expand Down
12 changes: 6 additions & 6 deletions src/tim/transform/ops/deconv2d_layout_inference.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,10 +79,10 @@ class DeConv2dLayoutInfer : public OpLayoutInfer {
auto final_pv = input_pv->Reverse()->Add(required_pv);
if (!final_pv->IsAligned()) {
infer_input =
InsertPermute(context_->GetMapedTensor(input_tensors[0]), final_pv);
InsertPermute(context_->GetMappedTensor(input_tensors[0]), final_pv);
context_->SetPermuteVector(input_tensors[0], required_pv);
} else {
infer_input = context_->GetMapedTensor(input_tensors[0]);
infer_input = context_->GetMappedTensor(input_tensors[0]);
context_->SetPermuteVector(input_tensors[0], input_pv);
}
context_->UpdateTensorMap(input_tensors[0], infer_input);
Expand All @@ -104,10 +104,10 @@ class DeConv2dLayoutInfer : public OpLayoutInfer {
auto final_pv = weight_pv->Reverse()->Add(weight_required_pv);
if (!final_pv->IsAligned()) {
infer_weight =
InsertPermute(context_->GetMapedTensor(input_tensors[1]), final_pv);
InsertPermute(context_->GetMappedTensor(input_tensors[1]), final_pv);
context_->SetPermuteVector(input_tensors[1], weight_required_pv);
} else {
infer_weight = context_->GetMapedTensor(input_tensors[1]);
infer_weight = context_->GetMappedTensor(input_tensors[1]);
context_->SetPermuteVector(input_tensors[1], weight_pv);
}
context_->UpdateTensorMap(input_tensors[1], infer_weight);
Expand All @@ -121,7 +121,7 @@ class DeConv2dLayoutInfer : public OpLayoutInfer {
infer_bias = context_->infer_graph_->CreateTensor(
input_tensors[2]->GetSpec(), (const void*)dataRef.data());
} else {
infer_bias = context_->GetMapedTensor(input_tensors[2]);
infer_bias = context_->GetMappedTensor(input_tensors[2]);
}
auto bias_pv = MakeShared(1);
context_->UpdateTensorMap(input_tensors[2], infer_bias);
Expand All @@ -131,7 +131,7 @@ class DeConv2dLayoutInfer : public OpLayoutInfer {
auto deconv = op_->Clone(context_->infer_graph_);
auto infer_out = CreateOutputsTensor(required_pv);
for (const auto& i_src : input_tensors) {
(*deconv).BindInput(context_->GetMapedTensor(i_src));
(*deconv).BindInput(context_->GetMappedTensor(i_src));
}
(*deconv).BindOutput(infer_out[0]);

Expand Down
2 changes: 1 addition & 1 deletion src/tim/transform/ops/default_layout_inference.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ class DefaultLayoutInfer : public OpLayoutInfer {
auto cloned_op = op_->Clone(context_->infer_graph_);

for (const auto& i_src : op_->impl()->InputsTensor()) {
(*cloned_op).BindInput(context_->GetMapedTensor(i_src));
(*cloned_op).BindInput(context_->GetMappedTensor(i_src));
}

std::vector<std::shared_ptr<IPermuteVector>> required_pv_lst;
Expand Down
4 changes: 2 additions & 2 deletions src/tim/transform/ops/depth2space_layout_inference.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ class DepthToSpaceLayoutInfer : public OpLayoutInfer {
auto final_pv = pv->Reverse()->Add(required_pv);
if (!final_pv->IsAligned()) {
auto perm_out =
InsertPermute(context_->GetMapedTensor(input_tensors[0]), final_pv);
InsertPermute(context_->GetMappedTensor(input_tensors[0]), final_pv);
context_->UpdateTensorMap(input_tensors[0], perm_out);
context_->SetPermuteVector(input_tensors[0], required_pv);
}
Expand All @@ -63,7 +63,7 @@ class DepthToSpaceLayoutInfer : public OpLayoutInfer {
context_->infer_graph_->CreateOperation<vx::ops::DepthToSpace>(
block_size, vx::DataLayout::WHCN);
auto out_tensor_infer = CreateOutputsTensor(required_pv);
(*space2depth).BindInput(context_->GetMapedTensor(input_tensors[0]));
(*space2depth).BindInput(context_->GetMappedTensor(input_tensors[0]));
(*space2depth).BindOutput(out_tensor_infer[0]);
context_->SetPermuteVector(op_->impl()->OutputsTensor()[0], required_pv);
// Add out tensor of src_graph into next_tensor
Expand Down
4 changes: 2 additions & 2 deletions src/tim/transform/ops/elementwise_layout_inference.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ class ElementWiseLayoutInfer : public OpLayoutInfer {
auto required_pv = AlignPermuteVectorForElementWise();
auto elementwise = context_->infer_graph_->CreateOperation<OpType>();
for (const auto& i_src : op_->impl()->InputsTensor()) {
(*elementwise).BindInput(context_->GetMapedTensor(i_src));
(*elementwise).BindInput(context_->GetMappedTensor(i_src));
}
auto out_infer = CreateOutputsTensor(required_pv);
(*elementwise).BindOutput(out_infer[0]);
Expand Down Expand Up @@ -120,7 +120,7 @@ class MultiplyLayoutInfer : public OpLayoutInfer {
context_->infer_graph_->CreateOperation<tim::vx::ops::Multiply>(
op_->impl()->node()->nn_param.multiply.scale);
for (const auto& i_src : op_->impl()->InputsTensor()) {
(*multiply).BindInput(context_->GetMapedTensor(i_src));
(*multiply).BindInput(context_->GetMappedTensor(i_src));
}
auto out_infer = CreateOutputsTensor(required_pv);
(*multiply).BindOutput(out_infer[0]);
Expand Down
2 changes: 1 addition & 1 deletion src/tim/transform/ops/fullyconnected_layout_inference.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ class FullyConnectedLayoutInfer : public OpLayoutInfer {
MakeShared(op_->impl()->OutputsTensor()[0]->GetShape().size());
auto out_infer = CreateOutputsTensor(required_pv);
for (auto in : op_->impl()->InputsTensor()) {
(*fcl).BindInput(context_->GetMapedTensor(in));
(*fcl).BindInput(context_->GetMappedTensor(in));
}
(*fcl).BindOutput(out_infer[0]);
context_->SetPermuteVector(op_->impl()->OutputsTensor()[0], required_pv);
Expand Down
2 changes: 1 addition & 1 deletion src/tim/transform/ops/gather_layout_inference.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ class GatherLayoutInfer : public OpLayoutInfer {
op_->impl()->node()->nn_param.gather.batch_dims);
int32_t output_rank = -1;
for (const auto& i_src : op_->impl()->InputsTensor()) {
(*gather).BindInput(context_->GetMapedTensor(i_src));
(*gather).BindInput(context_->GetMappedTensor(i_src));
output_rank += i_src->GetShape().size();
}
auto infer_out = CreateOutputsTensor(
Expand Down
2 changes: 1 addition & 1 deletion src/tim/transform/ops/gather_nd_layout_inference.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ class GatherNdLayoutInfer : public OpLayoutInfer {

auto gather = context_->infer_graph_->CreateOperation<vx::ops::GatherNd>();
for (const auto& i_src : op_->impl()->InputsTensor()) {
(*gather).BindInput(context_->GetMapedTensor(i_src));
(*gather).BindInput(context_->GetMappedTensor(i_src));
}
auto infer_out = CreateOutputsTensor(
context_->GetPermuteVector(op_->impl()->InputsTensor()[0]));
Expand Down
Loading
Loading