Keep graph output order in layout inference
Type: Code Improvement
Xiaoran Weng authored and Xiaoran Weng committed Dec 20, 2023
1 parent 4fd5e3a commit d2bfb43
Showing 3 changed files with 160 additions and 105 deletions.
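
In short, the change has three parts: the inferred graph's output tensors are now created eagerly, in the source graph's declaration order, and seeded into the traversal worklist; visited-op bookkeeping moves from a linearly searched vector to an unordered_map; and the worklist itself becomes a std::queue. Below is a minimal, self-contained sketch of the ordering idea only; Tensor here is a stand-in struct, not the tim::vx type, and the names merely mirror the diff.

#include <iostream>
#include <memory>
#include <queue>
#include <string>
#include <vector>

// Stand-in for tim::vx::Tensor; the sketch only needs a name per tensor.
struct Tensor {
  std::string name;
};

int main() {
  // Outputs as declared on the source graph, in user-visible order.
  std::vector<std::shared_ptr<Tensor>> src_outputs = {
      std::make_shared<Tensor>(Tensor{"out0"}),
      std::make_shared<Tensor>(Tensor{"out1"})};

  std::vector<std::shared_ptr<Tensor>> infer_outputs;  // inferred graph's outputs
  std::queue<std::shared_ptr<Tensor>> tensor_queue;    // traversal worklist

  // Mirrors the new loop in LayoutInference(): materialize each inferred
  // output up front, in source order, and still enqueue the source tensor
  // so traversal proceeds as before. Creating outputs lazily during
  // traversal would tie their order to whichever op happens to be
  // inferred first, which is what the commit title guards against.
  for (const auto& t_src : src_outputs) {
    infer_outputs.push_back(std::make_shared<Tensor>(*t_src));
    tensor_queue.push(t_src);
  }

  for (const auto& t : infer_outputs) std::cout << t->name << '\n';  // out0, out1
}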
14 changes: 10 additions & 4 deletions src/tim/transform/layout_infer_context.h
@@ -1,16 +1,18 @@
#ifndef TIM_VX_LAYOUT_INFER_CONTEXT_H_
#define TIM_VX_LAYOUT_INFER_CONTEXT_H_

#include "permute_vector.h"
#include "tim/transform/layout_inference.h"

#include <unordered_map>

namespace tim {
namespace transform {
namespace layout_inference_impl {
class LayoutInferContext {
public:
LayoutInferContext(const std::shared_ptr<vx::Graph>& src_graph,
std::shared_ptr<vx::Graph>& infer_graph)
: src_graph_(src_graph), infer_graph_(infer_graph) {}
std::shared_ptr<vx::Graph>& infer_graph);
void SetPermuteVector(std::shared_ptr<vx::Tensor> tensor,
std::shared_ptr<IPermuteVector> pv);
const std::shared_ptr<IPermuteVector> GetPermuteVector(
@@ -22,12 +24,16 @@ class LayoutInferContext {
const std::shared_ptr<vx::Tensor>& t_layout);
std::shared_ptr<vx::Tensor> GetMapedTensor(
const std::shared_ptr<vx::Tensor>& t_src) const;
std::shared_ptr<vx::Tensor> GetMappedGraphInputTensor(
const std::shared_ptr<vx::Tensor>& t_src) const;
std::shared_ptr<vx::Tensor> GetMappedGraphOutputTensor(
const std::shared_ptr<vx::Tensor>& t_src) const;

void UpdateGraphInputMap(const std::shared_ptr<vx::Tensor>& i_src,
const std::shared_ptr<vx::Tensor>& i_layout);

void UpdateGraphOutputMap(const std::shared_ptr<vx::Tensor>& o_src,
const std::shared_ptr<vx::Tensor>& o_layout);

std::map<std::shared_ptr<vx::Tensor>, std::shared_ptr<vx::Tensor>>
GetGraphInputMap() const {
@@ -44,7 +50,7 @@ class LayoutInferContext {
private:
std::map<std::shared_ptr<vx::Tensor>, std::shared_ptr<IPermuteVector>>
tensor_pv_;
std::vector<std::shared_ptr<vx::Operation>> visited_op_;
std::unordered_map<std::shared_ptr<vx::Operation>, bool> op_visited_;
// tensor_in_src -> tensor_in_layout
std::map<std::shared_ptr<vx::Tensor>, std::shared_ptr<vx::Tensor>>
tensor_map_;
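
The header change above is the bookkeeping half of the commit: the visited-op list becomes a hash map that the new out-of-line constructor pre-populates from src_graph->OpVector(). A rough before/after sketch of the lookup, assuming nothing beyond the standard library's hashing of shared_ptr:

#include <algorithm>
#include <memory>
#include <unordered_map>
#include <vector>

struct Operation {};  // stand-in for vx::Operation

// Before: every visited check is an O(n) scan of the list.
bool IsVisitedOld(const std::vector<std::shared_ptr<Operation>>& visited_op,
                  const std::shared_ptr<Operation>& op) {
  return std::find(visited_op.begin(), visited_op.end(), op) != visited_op.end();
}

// After: O(1) average lookup. at() throws std::out_of_range for an unknown
// op, which the new constructor rules out by registering every op of the
// source graph as unvisited up front.
bool IsVisitedNew(
    const std::unordered_map<std::shared_ptr<Operation>, bool>& op_visited,
    const std::shared_ptr<Operation>& op) {
  return op_visited.at(op);
}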
174 changes: 106 additions & 68 deletions src/tim/transform/layout_inference.cc
@@ -73,7 +73,7 @@
#include "ops/roi_pool_layout_inference.h"

#include <algorithm>
#include <deque>
#include <queue>

#include "tim/vx/context.h"
#include "tim/vx/graph.h"
@@ -87,7 +87,16 @@ std::vector<std::shared_ptr<vx::Tensor>> HandleLayoutInfer(
std::shared_ptr<layout_inference_impl::LayoutInferContext>& ctx,
const std::shared_ptr<vx::Operation>& op);

// Implemention for LayoutInferContext
// Implementation for LayoutInferContext
LayoutInferContext::LayoutInferContext(
const std::shared_ptr<vx::Graph>& src_graph,
std::shared_ptr<vx::Graph>& infer_graph)
: src_graph_(src_graph), infer_graph_(infer_graph) {
for (const auto& op : src_graph->OpVector()) {
op_visited_[op] = false;
}
}

void LayoutInferContext::SetPermuteVector(std::shared_ptr<vx::Tensor> tensor,
std::shared_ptr<IPermuteVector> pv) {
if (tensor_pv_.end() != tensor_pv_.find(tensor)) {
@@ -110,27 +119,19 @@ const std::shared_ptr<IPermuteVector> LayoutInferContext::GetPermuteVector(
}

void LayoutInferContext::MarkVisited(const std::shared_ptr<vx::Operation>& op) {
if (visited_op_.end() !=
std::find(visited_op_.begin(), visited_op_.end(), op)) {
VSILOGW("The operation has been mark as visited.");
} else {
visited_op_.push_back(op);
}
op_visited_[op] = true;
}

bool LayoutInferContext::IsVisited(const std::shared_ptr<vx::Operation>& op) const {
if (visited_op_.end() !=
std::find(visited_op_.begin(), visited_op_.end(), op)) {
return true;
} else {
return false;
}
bool LayoutInferContext::IsVisited(
const std::shared_ptr<vx::Operation>& op) const {
return op_visited_.at(op);
}

bool LayoutInferContext::IsReadyForInfer(
const std::shared_ptr<vx::Operation>& op) const {
for (const auto& tensor : op->impl()->InputsTensor()) {
if (!tensor->IsConstTensor() && tensor->GetId() != (uint32_t)-1 &&
if (!tensor->IsConstTensor() &&
tensor->GetId() != static_cast<uint32_t>(-1) &&
(tensor_pv_.end() == tensor_pv_.find(tensor))) {
return false;
}
@@ -149,21 +150,43 @@ std::shared_ptr<vx::Tensor> LayoutInferContext::GetMapedTensor(
auto it = tensor_map_.find(t_src);
if (it != tensor_map_.end()) {
return it->second;
} else {
VSILOGE("Tensor has not beed inserted in tensor map.");
assert(false);
}

VSILOGE("Tensor has not beed inserted in tensor map.");
return nullptr;
}

std::shared_ptr<vx::Tensor> LayoutInferContext::GetMappedGraphInputTensor(
const std::shared_ptr<vx::Tensor>& t_src) const {
auto it = graph_input_map_.find(t_src);
if (it != graph_input_map_.end()) {
return it->second;
}

VSILOGE("Tensor has not beed inserted in graph input tensor map.");
return nullptr;
}

void LayoutInferContext::UpdateGraphInputMap(const std::shared_ptr<vx::Tensor>& i_src,
const std::shared_ptr<vx::Tensor>& i_layout) {
std::shared_ptr<vx::Tensor> LayoutInferContext::GetMappedGraphOutputTensor(
const std::shared_ptr<vx::Tensor>& t_src) const {
auto it = graph_output_map_.find(t_src);
if (it != graph_output_map_.end()) {
return it->second;
}

VSILOGE("Tensor has not beed inserted in graph output tensor map.");
return nullptr;
}

void LayoutInferContext::UpdateGraphInputMap(
const std::shared_ptr<vx::Tensor>& i_src,
const std::shared_ptr<vx::Tensor>& i_layout) {
graph_input_map_[i_src] = i_layout;
}

void LayoutInferContext::UpdateGraphOutputMap(const std::shared_ptr<vx::Tensor>& o_src,
const std::shared_ptr<vx::Tensor>& o_layout) {
void LayoutInferContext::UpdateGraphOutputMap(
const std::shared_ptr<vx::Tensor>& o_src,
const std::shared_ptr<vx::Tensor>& o_layout) {
graph_output_map_[o_src] = o_layout;
}

@@ -173,39 +196,40 @@ void LayoutInferContext::UpdateGraphOutputMap(const std::shared_ptr<vx::Tensor>&
op_infer->OnInputs(next_tensors); \
op_infer->OnOutputs(next_tensors); \
break; \
} \

#define REGIST_REDUCE_LAYOUT_INFERENCE(op_idx) \
case op_idx: { \
auto reduce_type = op->impl()->node()->nn_param.reduce.type; \
switch (reduce_type) { \
REGIST_LAYOUT_INFERENCE(VSI_NN_REDUCE_MEAN, ReduceMean); \
REGIST_LAYOUT_INFERENCE(VSI_NN_REDUCE_MAX, ReduceMax); \
REGIST_LAYOUT_INFERENCE(VSI_NN_REDUCE_MIN, ReduceMin); \
REGIST_LAYOUT_INFERENCE(VSI_NN_REDUCE_PROD, ReduceProd); \
REGIST_LAYOUT_INFERENCE(VSI_NN_REDUCE_ANY, ReduceAny); \
REGIST_LAYOUT_INFERENCE(VSI_NN_REDUCE_SUM, ReduceSum); \
REGIST_LAYOUT_INFERENCE(VSI_NN_REDUCE_ALL, ReduceAll); \
default: \
VSILOGW("Op %d: Default layout inference pass for reduce.", reduce_type);\
assert(false); \
} \
break; \
} \

#define REGIST_LOGICAL_LAYOUT_INFERENCE(op_idx) \
case op_idx: { \
auto logical_type = op->impl()->node()->nn_param.relational_ops.op; \
switch (logical_type) \
{ \
REGIST_LAYOUT_INFERENCE(VSI_NN_LOGICAL_AND, LogicalAnd); \
REGIST_LAYOUT_INFERENCE(VSI_NN_LOGICAL_OR, LogicalOr); \
default: \
VSILOGW("Op %d: Default layout inference pass for logical.", logical_type);\
assert(false); \
} \
break; \
} \
}

#define REGIST_REDUCE_LAYOUT_INFERENCE(op_idx) \
case op_idx: { \
auto reduce_type = op->impl()->node()->nn_param.reduce.type; \
switch (reduce_type) { \
REGIST_LAYOUT_INFERENCE(VSI_NN_REDUCE_MEAN, ReduceMean); \
REGIST_LAYOUT_INFERENCE(VSI_NN_REDUCE_MAX, ReduceMax); \
REGIST_LAYOUT_INFERENCE(VSI_NN_REDUCE_MIN, ReduceMin); \
REGIST_LAYOUT_INFERENCE(VSI_NN_REDUCE_PROD, ReduceProd); \
REGIST_LAYOUT_INFERENCE(VSI_NN_REDUCE_ANY, ReduceAny); \
REGIST_LAYOUT_INFERENCE(VSI_NN_REDUCE_SUM, ReduceSum); \
REGIST_LAYOUT_INFERENCE(VSI_NN_REDUCE_ALL, ReduceAll); \
default: \
VSILOGW("Op %d: Default layout inference pass for reduce.", \
reduce_type); \
assert(false); \
} \
break; \
}

#define REGIST_LOGICAL_LAYOUT_INFERENCE(op_idx) \
case op_idx: { \
auto logical_type = op->impl()->node()->nn_param.relational_ops.op; \
switch (logical_type) { \
REGIST_LAYOUT_INFERENCE(VSI_NN_LOGICAL_AND, LogicalAnd); \
REGIST_LAYOUT_INFERENCE(VSI_NN_LOGICAL_OR, LogicalOr); \
default: \
VSILOGW("Op %d: Default layout inference pass for logical.", \
logical_type); \
assert(false); \
} \
break; \
}

std::vector<std::shared_ptr<vx::Tensor>> HandleLayoutInfer(
std::shared_ptr<layout_inference_impl::LayoutInferContext>& ctx,
@@ -279,8 +303,10 @@ std::vector<std::shared_ptr<vx::Tensor>> HandleLayoutInfer(
REGIST_LAYOUT_INFERENCE(VSI_NN_OP_CONV3D, Conv3d);
REGIST_LAYOUT_INFERENCE(VSI_NN_OP_LSTM_OVXLIB, UnidirectionalLstm);
REGIST_LAYOUT_INFERENCE(VSI_NN_OP_EXPAND_BROADCAST, Broadcast);
REGIST_LAYOUT_INFERENCE(VSI_NN_OP_UNIDIRECTIONAL_SEQUENCE_RNN, UnidirectionalRnn);
REGIST_LAYOUT_INFERENCE(VSI_NN_OP_BIDIRECTIONAL_SEQUENCE_RNN, BidirectionalRnn);
REGIST_LAYOUT_INFERENCE(VSI_NN_OP_UNIDIRECTIONAL_SEQUENCE_RNN,
UnidirectionalRnn);
REGIST_LAYOUT_INFERENCE(VSI_NN_OP_BIDIRECTIONAL_SEQUENCE_RNN,
BidirectionalRnn);
#ifdef VSI_FEAT_OP_CUSTOM_TINY_YOLOV4_POSTPROCESS
REGIST_LAYOUT_INFERENCE(VSI_NN_OP_CUSTOM_TINY_YOLOV4_POSTPROCESS, Yolov4);
#endif
@@ -312,13 +338,13 @@ LayoutInference(
std::make_shared<layout_inference_impl::LayoutInferContext>(src_graph,
infer_graph);

std::deque<std::shared_ptr<vx::Tensor>> tensor_queue;
std::queue<std::shared_ptr<vx::Tensor>> tensor_queue;
auto graph_inputs = src_graph->InputsTensor();
for (const auto& t_src : graph_inputs) {
auto input = infer_graph->CreateTensor(t_src->GetSpec());
layout_infer_ctx->UpdateTensorMap(t_src, input);
layout_infer_ctx->UpdateGraphInputMap(t_src, input);
tensor_queue.push_back(t_src);
tensor_queue.push(t_src);
layout_infer_ctx->SetPermuteVector(
t_src, tensor_pv_map.find(t_src) != tensor_pv_map.end()
? tensor_pv_map[t_src]
@@ -329,27 +355,39 @@
for (auto const_in : const_inputs) {
std::vector<uint8_t> dataRef(const_in->GetSpec().GetByteSize());
const_in->CopyDataFromTensor(dataRef.data());
auto input =
infer_graph->CreateTensor(const_in->GetSpec(), (const void*)dataRef.data());
auto input = infer_graph->CreateTensor(const_in->GetSpec(),
(const void*)dataRef.data());
layout_infer_ctx->UpdateTensorMap(const_in, input);
tensor_queue.push_back(const_in);
tensor_queue.push(const_in);
layout_infer_ctx->SetPermuteVector(
const_in, tensor_pv_map.find(const_in) != tensor_pv_map.end()
? tensor_pv_map[const_in]
: MakeShared(const_in->GetShape().size()));
? tensor_pv_map[const_in]
: MakeShared(const_in->GetShape().size()));
}

auto graph_outputs = src_graph->OutputsTensor();
for (const auto& t_src : graph_outputs) {
auto output = infer_graph->CreateTensor(t_src->GetSpec());
layout_infer_ctx->UpdateTensorMap(t_src, output);
layout_infer_ctx->UpdateGraphOutputMap(t_src, output);
tensor_queue.push(t_src);
layout_infer_ctx->SetPermuteVector(
t_src, tensor_pv_map.find(t_src) != tensor_pv_map.end()
? tensor_pv_map[t_src]
: MakeShared(t_src->GetShape().size()));
}

while (!tensor_queue.empty()) {
auto tensor = tensor_queue.front();
tensor_queue.pop_front();
tensor_queue.pop();
const auto& consumers = src_graph->GetConsumersOp(tensor);
for (const auto& op : consumers) {
if (!layout_infer_ctx->IsVisited(op) && op->impl()->kind_ !=-1 &&
if (!layout_infer_ctx->IsVisited(op) && op->impl()->kind_ != -1 &&
layout_infer_ctx->IsReadyForInfer(op)) {
auto next_tensors =
layout_inference_impl::HandleLayoutInfer(layout_infer_ctx, op);
for (const auto& t : next_tensors) {
tensor_queue.push_back(t);
tensor_queue.push(t);
}
}
}
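
One last mechanical change in this file: the worklist type moves from std::deque to std::queue, with push_back/pop_front becoming push/pop. Judging from the diff, the deque was only ever used FIFO, so this narrows the interface without changing behavior; a tiny sketch of the renamed calls:

#include <memory>
#include <queue>

struct Tensor {};  // stand-in

int main() {
  std::queue<std::shared_ptr<Tensor>> tensor_queue;  // was std::deque
  tensor_queue.push(std::make_shared<Tensor>());     // was push_back
  auto t = tensor_queue.front();                     // unchanged
  tensor_queue.pop();                                // was pop_front
  (void)t;
}

std::queue also closes off accidental non-FIFO access, which makes the breadth-first intent of the traversal explicit.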