
Commit

Fix typos in layout inference
Type: Code Improvement
Xiaoran Weng authored and committed on Dec 20, 2023
1 parent d2bfb43 commit 84399b9
Showing 47 changed files with 177 additions and 177 deletions.
2 changes: 1 addition & 1 deletion src/tim/transform/layout_infer_context.h
@@ -22,7 +22,7 @@ class LayoutInferContext {
bool IsReadyForInfer(const std::shared_ptr<vx::Operation>& op) const;
void UpdateTensorMap(const std::shared_ptr<vx::Tensor>& t_src,
const std::shared_ptr<vx::Tensor>& t_layout);
-  std::shared_ptr<vx::Tensor> GetMapedTensor(
+  std::shared_ptr<vx::Tensor> GetMappedTensor(
const std::shared_ptr<vx::Tensor>& t_src) const;
std::shared_ptr<vx::Tensor> GetMappedGraphInputTensor(
const std::shared_ptr<vx::Tensor>& t_src) const;
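Note on the renamed accessor: GetMappedTensor looks up the layout-inferred counterpart of a source-graph tensor in the context's tensor map (its body appears in layout_inference.cc below). A minimal self-contained sketch of the pattern, assuming tensor_map_ is an unordered_map keyed by the source tensor's shared_ptr and that a missing entry yields nullptr (the fallback is an assumption; the real class carries more state):

// Sketch only: Tensor stands in for tim::vx::Tensor; the nullptr fallback
// for unmapped tensors is an assumption, not confirmed by the diff.
#include <memory>
#include <unordered_map>

struct Tensor {};

class TensorMapSketch {
 public:
  // Record the layout-inferred counterpart of a source-graph tensor.
  void UpdateTensorMap(const std::shared_ptr<Tensor>& t_src,
                       const std::shared_ptr<Tensor>& t_layout) {
    tensor_map_[t_src] = t_layout;
  }
  // Look up the counterpart; nullptr if the tensor has not been mapped yet.
  std::shared_ptr<Tensor> GetMappedTensor(
      const std::shared_ptr<Tensor>& t_src) const {
    auto it = tensor_map_.find(t_src);
    return it != tensor_map_.end() ? it->second : nullptr;
  }

 private:
  std::unordered_map<std::shared_ptr<Tensor>, std::shared_ptr<Tensor>>
      tensor_map_;
};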
170 changes: 85 additions & 85 deletions src/tim/transform/layout_inference.cc
@@ -145,7 +145,7 @@ void LayoutInferContext::UpdateTensorMap(
tensor_map_[t_src] = t_layout;
}

-std::shared_ptr<vx::Tensor> LayoutInferContext::GetMapedTensor(
+std::shared_ptr<vx::Tensor> LayoutInferContext::GetMappedTensor(
const std::shared_ptr<vx::Tensor>& t_src) const {
auto it = tensor_map_.find(t_src);
if (it != tensor_map_.end()) {
@@ -190,25 +190,25 @@ void LayoutInferContext::UpdateGraphOutputMap(
graph_output_map_[o_src] = o_layout;
}

-#define REGIST_LAYOUT_INFERENCE(op_idx, name) \
+#define REGISTER_LAYOUT_INFERENCE(op_idx, name) \
case op_idx: { \
auto op_infer = std::make_shared<name##LayoutInfer>(op, ctx); \
op_infer->OnInputs(next_tensors); \
op_infer->OnOutputs(next_tensors); \
break; \
}

-#define REGIST_REDUCE_LAYOUT_INFERENCE(op_idx) \
+#define REGISTER_REDUCE_LAYOUT_INFERENCE(op_idx) \
case op_idx: { \
auto reduce_type = op->impl()->node()->nn_param.reduce.type; \
switch (reduce_type) { \
-      REGIST_LAYOUT_INFERENCE(VSI_NN_REDUCE_MEAN, ReduceMean); \
-      REGIST_LAYOUT_INFERENCE(VSI_NN_REDUCE_MAX, ReduceMax); \
-      REGIST_LAYOUT_INFERENCE(VSI_NN_REDUCE_MIN, ReduceMin); \
-      REGIST_LAYOUT_INFERENCE(VSI_NN_REDUCE_PROD, ReduceProd); \
-      REGIST_LAYOUT_INFERENCE(VSI_NN_REDUCE_ANY, ReduceAny); \
-      REGIST_LAYOUT_INFERENCE(VSI_NN_REDUCE_SUM, ReduceSum); \
-      REGIST_LAYOUT_INFERENCE(VSI_NN_REDUCE_ALL, ReduceAll); \
+      REGISTER_LAYOUT_INFERENCE(VSI_NN_REDUCE_MEAN, ReduceMean); \
+      REGISTER_LAYOUT_INFERENCE(VSI_NN_REDUCE_MAX, ReduceMax); \
+      REGISTER_LAYOUT_INFERENCE(VSI_NN_REDUCE_MIN, ReduceMin); \
+      REGISTER_LAYOUT_INFERENCE(VSI_NN_REDUCE_PROD, ReduceProd); \
+      REGISTER_LAYOUT_INFERENCE(VSI_NN_REDUCE_ANY, ReduceAny); \
+      REGISTER_LAYOUT_INFERENCE(VSI_NN_REDUCE_SUM, ReduceSum); \
+      REGISTER_LAYOUT_INFERENCE(VSI_NN_REDUCE_ALL, ReduceAll); \
default: \
VSILOGW("Op %d: Default layout inference pass for reduce.", \
reduce_type); \
@@ -217,12 +217,12 @@ void LayoutInferContext::UpdateGraphOutputMap(
break; \
}

-#define REGIST_LOGICAL_LAYOUT_INFERENCE(op_idx) \
+#define REGISTER_LOGICAL_LAYOUT_INFERENCE(op_idx) \
case op_idx: { \
auto logical_type = op->impl()->node()->nn_param.relational_ops.op; \
switch (logical_type) { \
-      REGIST_LAYOUT_INFERENCE(VSI_NN_LOGICAL_AND, LogicalAnd); \
-      REGIST_LAYOUT_INFERENCE(VSI_NN_LOGICAL_OR, LogicalOr); \
+      REGISTER_LAYOUT_INFERENCE(VSI_NN_LOGICAL_AND, LogicalAnd); \
+      REGISTER_LAYOUT_INFERENCE(VSI_NN_LOGICAL_OR, LogicalOr); \
default: \
VSILOGW("Op %d: Default layout inference pass for logical.", \
logical_type); \
@@ -238,80 +238,80 @@ std::vector<std::shared_ptr<vx::Tensor>> HandleLayoutInfer(
auto op_id = op->impl()->kind_;
std::vector<std::shared_ptr<vx::Tensor>> next_tensors;
switch (op_id) {
-    REGIST_LAYOUT_INFERENCE(VSI_NN_OP_CONV2D, Conv2d);
-    REGIST_LAYOUT_INFERENCE(VSI_NN_OP_GROUPED_CONV2D, GroupedConv2d);
-    REGIST_LAYOUT_INFERENCE(VSI_NN_OP_RELU, Relu);
-    REGIST_LAYOUT_INFERENCE(VSI_NN_OP_RELU1, Relu1);
-    REGIST_LAYOUT_INFERENCE(VSI_NN_OP_RELU6, Relu6);
-    REGIST_LAYOUT_INFERENCE(VSI_NN_OP_ELU, Elu);
-    REGIST_LAYOUT_INFERENCE(VSI_NN_OP_SIGMOID, Sigmoid);
-    REGIST_LAYOUT_INFERENCE(VSI_NN_OP_MISH, Mish);
-    REGIST_LAYOUT_INFERENCE(VSI_NN_OP_HARD_SIGMOID, HardSigmoid);
-    REGIST_LAYOUT_INFERENCE(VSI_NN_OP_SOFTRELU, SoftRelu);
-    REGIST_LAYOUT_INFERENCE(VSI_NN_OP_SWISH, HardSwish);
-    REGIST_LAYOUT_INFERENCE(VSI_NN_OP_TANH, Tanh);
-    REGIST_LAYOUT_INFERENCE(VSI_NN_OP_LEAKY_RELU, LeakyRelu);
-    REGIST_LAYOUT_INFERENCE(VSI_NN_OP_CONCAT, Concat);
-    REGIST_LAYOUT_INFERENCE(VSI_NN_OP_ADD, Add);
-    REGIST_LAYOUT_INFERENCE(VSI_NN_OP_SUBTRACT, Sub);
-    REGIST_LAYOUT_INFERENCE(VSI_NN_OP_MULTIPLY, Multiply);
-    REGIST_LAYOUT_INFERENCE(VSI_NN_OP_DIVIDE, Div);
-    REGIST_LAYOUT_INFERENCE(VSI_NN_OP_POW, Pow);
-    REGIST_LAYOUT_INFERENCE(VSI_NN_OP_MINIMUM, Minimum);
-    REGIST_LAYOUT_INFERENCE(VSI_NN_OP_MAXIMUM, Maximum);
-    REGIST_LAYOUT_INFERENCE(VSI_NN_OP_DATACONVERT, DataConvert);
-    REGIST_LAYOUT_INFERENCE(VSI_NN_OP_NEG, Neg);
-    REGIST_LAYOUT_INFERENCE(VSI_NN_OP_ABS, Abs);
-    REGIST_LAYOUT_INFERENCE(VSI_NN_OP_SIN, Sin);
-    REGIST_LAYOUT_INFERENCE(VSI_NN_OP_EXP, Exp);
-    REGIST_LAYOUT_INFERENCE(VSI_NN_OP_LOG, Log);
-    REGIST_LAYOUT_INFERENCE(VSI_NN_OP_SQRT, Sqrt);
-    REGIST_LAYOUT_INFERENCE(VSI_NN_OP_RSQRT, Rsqrt);
-    REGIST_LAYOUT_INFERENCE(VSI_NN_OP_SQUARE, Square);
-    REGIST_LAYOUT_INFERENCE(VSI_NN_OP_LOGICAL_NOT, LogicalNot);
-    REGIST_LAYOUT_INFERENCE(VSI_NN_OP_POOL, Pool2d);
-    REGIST_LAYOUT_INFERENCE(VSI_NN_OP_SOFTMAX, Softmax);
-    REGIST_LAYOUT_INFERENCE(VSI_NN_OP_SQUEEZE, Squeeze);
-    REGIST_LAYOUT_INFERENCE(VSI_NN_OP_STACK, Stack);
-    REGIST_LAYOUT_INFERENCE(VSI_NN_OP_SPACE2DEPTH, SpaceToDepth);
-    REGIST_LAYOUT_INFERENCE(VSI_NN_OP_DEPTH2SPACE, DepthToSpace);
-    REGIST_LAYOUT_INFERENCE(VSI_NN_OP_SPACE2BATCH, Space2Batch);
-    REGIST_LAYOUT_INFERENCE(VSI_NN_OP_BATCH2SPACE, Batch2Space);
-    REGIST_LAYOUT_INFERENCE(VSI_NN_OP_PAD, Pad);
-    REGIST_LAYOUT_INFERENCE(VSI_NN_OP_PAD2, PadV2);
-    REGIST_LAYOUT_INFERENCE(VSI_NN_OP_FCL2, FullyConnected);
-    REGIST_LAYOUT_INFERENCE(VSI_NN_OP_RESIZE, Resize);
-    REGIST_LAYOUT_INFERENCE(VSI_NN_OP_SPLIT, Split);
-    REGIST_LAYOUT_INFERENCE(VSI_NN_OP_STRIDED_SLICE, StridedSlice);
-    REGIST_LAYOUT_INFERENCE(VSI_NN_OP_LRN2, LRN);
-    REGIST_LAYOUT_INFERENCE(VSI_NN_OP_L2_NORMALIZE, L2Normalization);
-    REGIST_LAYOUT_INFERENCE(VSI_NN_OP_INSTANCE_NORM, InstanceNorm);
-    REGIST_LAYOUT_INFERENCE(VSI_NN_OP_ROI_ALIGN, RoiAlign);
-    REGIST_LAYOUT_INFERENCE(VSI_NN_OP_ROI_POOL, RoiPool);
-    REGIST_LAYOUT_INFERENCE(VSI_NN_OP_ADDN, AddN);
-    REGIST_LAYOUT_INFERENCE(VSI_NN_OP_PRELU, PRelu);
-    REGIST_LAYOUT_INFERENCE(VSI_NN_OP_GATHER, Gather);
-    REGIST_LAYOUT_INFERENCE(VSI_NN_OP_GATHER_ND, GatherNd);
-    REGIST_LAYOUT_INFERENCE(VSI_NN_OP_REVERSE, Reverse);
-    REGIST_LAYOUT_INFERENCE(VSI_NN_OP_SLICE, Slice);
-    REGIST_LAYOUT_INFERENCE(VSI_NN_OP_SELECT, Select);
-    REGIST_LAYOUT_INFERENCE(VSI_NN_OP_ARGMAX, Arg);
-    REGIST_LAYOUT_INFERENCE(VSI_NN_OP_ARGMIN, Arg);
-    REGIST_LAYOUT_INFERENCE(VSI_NN_OP_DECONVOLUTION, DeConv2d);
-    REGIST_LAYOUT_INFERENCE(VSI_NN_OP_BATCH_NORM, BatchNorm);
-    REGIST_LAYOUT_INFERENCE(VSI_NN_OP_PERMUTE, Transpose);
-    REGIST_LAYOUT_INFERENCE(VSI_NN_OP_CONV3D, Conv3d);
-    REGIST_LAYOUT_INFERENCE(VSI_NN_OP_LSTM_OVXLIB, UnidirectionalLstm);
-    REGIST_LAYOUT_INFERENCE(VSI_NN_OP_EXPAND_BROADCAST, Broadcast);
-    REGIST_LAYOUT_INFERENCE(VSI_NN_OP_UNIDIRECTIONAL_SEQUENCE_RNN,
-                            UnidirectionalRnn);
-    REGIST_LAYOUT_INFERENCE(VSI_NN_OP_BIDIRECTIONAL_SEQUENCE_RNN,
-                            BidirectionalRnn);
+    REGISTER_LAYOUT_INFERENCE(VSI_NN_OP_CONV2D, Conv2d);
+    REGISTER_LAYOUT_INFERENCE(VSI_NN_OP_GROUPED_CONV2D, GroupedConv2d);
+    REGISTER_LAYOUT_INFERENCE(VSI_NN_OP_RELU, Relu);
+    REGISTER_LAYOUT_INFERENCE(VSI_NN_OP_RELU1, Relu1);
+    REGISTER_LAYOUT_INFERENCE(VSI_NN_OP_RELU6, Relu6);
+    REGISTER_LAYOUT_INFERENCE(VSI_NN_OP_ELU, Elu);
+    REGISTER_LAYOUT_INFERENCE(VSI_NN_OP_SIGMOID, Sigmoid);
+    REGISTER_LAYOUT_INFERENCE(VSI_NN_OP_MISH, Mish);
+    REGISTER_LAYOUT_INFERENCE(VSI_NN_OP_HARD_SIGMOID, HardSigmoid);
+    REGISTER_LAYOUT_INFERENCE(VSI_NN_OP_SOFTRELU, SoftRelu);
+    REGISTER_LAYOUT_INFERENCE(VSI_NN_OP_SWISH, HardSwish);
+    REGISTER_LAYOUT_INFERENCE(VSI_NN_OP_TANH, Tanh);
+    REGISTER_LAYOUT_INFERENCE(VSI_NN_OP_LEAKY_RELU, LeakyRelu);
+    REGISTER_LAYOUT_INFERENCE(VSI_NN_OP_CONCAT, Concat);
+    REGISTER_LAYOUT_INFERENCE(VSI_NN_OP_ADD, Add);
+    REGISTER_LAYOUT_INFERENCE(VSI_NN_OP_SUBTRACT, Sub);
+    REGISTER_LAYOUT_INFERENCE(VSI_NN_OP_MULTIPLY, Multiply);
+    REGISTER_LAYOUT_INFERENCE(VSI_NN_OP_DIVIDE, Div);
+    REGISTER_LAYOUT_INFERENCE(VSI_NN_OP_POW, Pow);
+    REGISTER_LAYOUT_INFERENCE(VSI_NN_OP_MINIMUM, Minimum);
+    REGISTER_LAYOUT_INFERENCE(VSI_NN_OP_MAXIMUM, Maximum);
+    REGISTER_LAYOUT_INFERENCE(VSI_NN_OP_DATACONVERT, DataConvert);
+    REGISTER_LAYOUT_INFERENCE(VSI_NN_OP_NEG, Neg);
+    REGISTER_LAYOUT_INFERENCE(VSI_NN_OP_ABS, Abs);
+    REGISTER_LAYOUT_INFERENCE(VSI_NN_OP_SIN, Sin);
+    REGISTER_LAYOUT_INFERENCE(VSI_NN_OP_EXP, Exp);
+    REGISTER_LAYOUT_INFERENCE(VSI_NN_OP_LOG, Log);
+    REGISTER_LAYOUT_INFERENCE(VSI_NN_OP_SQRT, Sqrt);
+    REGISTER_LAYOUT_INFERENCE(VSI_NN_OP_RSQRT, Rsqrt);
+    REGISTER_LAYOUT_INFERENCE(VSI_NN_OP_SQUARE, Square);
+    REGISTER_LAYOUT_INFERENCE(VSI_NN_OP_LOGICAL_NOT, LogicalNot);
+    REGISTER_LAYOUT_INFERENCE(VSI_NN_OP_POOL, Pool2d);
+    REGISTER_LAYOUT_INFERENCE(VSI_NN_OP_SOFTMAX, Softmax);
+    REGISTER_LAYOUT_INFERENCE(VSI_NN_OP_SQUEEZE, Squeeze);
+    REGISTER_LAYOUT_INFERENCE(VSI_NN_OP_STACK, Stack);
+    REGISTER_LAYOUT_INFERENCE(VSI_NN_OP_SPACE2DEPTH, SpaceToDepth);
+    REGISTER_LAYOUT_INFERENCE(VSI_NN_OP_DEPTH2SPACE, DepthToSpace);
+    REGISTER_LAYOUT_INFERENCE(VSI_NN_OP_SPACE2BATCH, Space2Batch);
+    REGISTER_LAYOUT_INFERENCE(VSI_NN_OP_BATCH2SPACE, Batch2Space);
+    REGISTER_LAYOUT_INFERENCE(VSI_NN_OP_PAD, Pad);
+    REGISTER_LAYOUT_INFERENCE(VSI_NN_OP_PAD2, PadV2);
+    REGISTER_LAYOUT_INFERENCE(VSI_NN_OP_FCL2, FullyConnected);
+    REGISTER_LAYOUT_INFERENCE(VSI_NN_OP_RESIZE, Resize);
+    REGISTER_LAYOUT_INFERENCE(VSI_NN_OP_SPLIT, Split);
+    REGISTER_LAYOUT_INFERENCE(VSI_NN_OP_STRIDED_SLICE, StridedSlice);
+    REGISTER_LAYOUT_INFERENCE(VSI_NN_OP_LRN2, LRN);
+    REGISTER_LAYOUT_INFERENCE(VSI_NN_OP_L2_NORMALIZE, L2Normalization);
+    REGISTER_LAYOUT_INFERENCE(VSI_NN_OP_INSTANCE_NORM, InstanceNorm);
+    REGISTER_LAYOUT_INFERENCE(VSI_NN_OP_ROI_ALIGN, RoiAlign);
+    REGISTER_LAYOUT_INFERENCE(VSI_NN_OP_ROI_POOL, RoiPool);
+    REGISTER_LAYOUT_INFERENCE(VSI_NN_OP_ADDN, AddN);
+    REGISTER_LAYOUT_INFERENCE(VSI_NN_OP_PRELU, PRelu);
+    REGISTER_LAYOUT_INFERENCE(VSI_NN_OP_GATHER, Gather);
+    REGISTER_LAYOUT_INFERENCE(VSI_NN_OP_GATHER_ND, GatherNd);
+    REGISTER_LAYOUT_INFERENCE(VSI_NN_OP_REVERSE, Reverse);
+    REGISTER_LAYOUT_INFERENCE(VSI_NN_OP_SLICE, Slice);
+    REGISTER_LAYOUT_INFERENCE(VSI_NN_OP_SELECT, Select);
+    REGISTER_LAYOUT_INFERENCE(VSI_NN_OP_ARGMAX, Arg);
+    REGISTER_LAYOUT_INFERENCE(VSI_NN_OP_ARGMIN, Arg);
+    REGISTER_LAYOUT_INFERENCE(VSI_NN_OP_DECONVOLUTION, DeConv2d);
+    REGISTER_LAYOUT_INFERENCE(VSI_NN_OP_BATCH_NORM, BatchNorm);
+    REGISTER_LAYOUT_INFERENCE(VSI_NN_OP_PERMUTE, Transpose);
+    REGISTER_LAYOUT_INFERENCE(VSI_NN_OP_CONV3D, Conv3d);
+    REGISTER_LAYOUT_INFERENCE(VSI_NN_OP_LSTM_OVXLIB, UnidirectionalLstm);
+    REGISTER_LAYOUT_INFERENCE(VSI_NN_OP_EXPAND_BROADCAST, Broadcast);
+    REGISTER_LAYOUT_INFERENCE(VSI_NN_OP_UNIDIRECTIONAL_SEQUENCE_RNN,
+                              UnidirectionalRnn);
+    REGISTER_LAYOUT_INFERENCE(VSI_NN_OP_BIDIRECTIONAL_SEQUENCE_RNN,
+                              BidirectionalRnn);
#ifdef VSI_FEAT_OP_CUSTOM_TINY_YOLOV4_POSTPROCESS
-    REGIST_LAYOUT_INFERENCE(VSI_NN_OP_CUSTOM_TINY_YOLOV4_POSTPROCESS, Yolov4);
+    REGISTER_LAYOUT_INFERENCE(VSI_NN_OP_CUSTOM_TINY_YOLOV4_POSTPROCESS, Yolov4);
#endif
-    REGIST_LOGICAL_LAYOUT_INFERENCE(VSI_NN_OP_LOGICAL_OPS);
-    REGIST_REDUCE_LAYOUT_INFERENCE(VSI_NN_OP_REDUCE);
+    REGISTER_LOGICAL_LAYOUT_INFERENCE(VSI_NN_OP_LOGICAL_OPS);
+    REGISTER_REDUCE_LAYOUT_INFERENCE(VSI_NN_OP_REDUCE);
// use default layout inference
default: {
VSILOGW("Op %d: default layout inference pass.", op_id);
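To decode the macro-heavy switch above: each REGISTER_LAYOUT_INFERENCE(op_idx, name) invocation pastes a case into the switch (op_id) of HandleLayoutInfer that instantiates the matching name##LayoutInfer handler. Per the macro body shown above, one registration expands as follows (the per-call comments are inferred from the handler implementations in the ops/ headers below):

// Expansion of REGISTER_LAYOUT_INFERENCE(VSI_NN_OP_CONV2D, Conv2d):
case VSI_NN_OP_CONV2D: {
  auto op_infer = std::make_shared<Conv2dLayoutInfer>(op, ctx);
  op_infer->OnInputs(next_tensors);   // bind mapped (possibly permuted) inputs
  op_infer->OnOutputs(next_tensors);  // bind outputs; next_tensors collects
                                      // source-graph outputs to visit next
  break;
}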
8 changes: 4 additions & 4 deletions src/tim/transform/ops/activation_layout_inference.h
@@ -51,7 +51,7 @@ class ActivationLayoutInfer : public OpLayoutInfer {
auto activation = op_->Clone(context_->infer_graph_);
auto out_infer = CreateOutputsTensor(input_pv);
(*activation)
-        .BindInput(context_->GetMapedTensor(i_src))
+        .BindInput(context_->GetMappedTensor(i_src))
.BindOutput(out_infer[0]);
context_->SetPermuteVector(op_->impl()->OutputsTensor()[0], input_pv);
next_tensors.push_back(op_->impl()->OutputsTensor()[0]);
@@ -107,7 +107,7 @@ class PReluLayoutInfer : public OpLayoutInfer {
context_->infer_graph_->CreateOperation<vx::ops::Reshape>(
boardcast_shape);
(*reshape)
-          .BindInput(context_->GetMapedTensor(src_slope))
+          .BindInput(context_->GetMappedTensor(src_slope))
.BindOutput(reshape_out);
context_->UpdateTensorMap(src_slope, reshape_out);
}
@@ -130,8 +130,8 @@ class PReluLayoutInfer : public OpLayoutInfer {
auto out_infer = CreateOutputsTensor(input_pv);

(*prelu)
-        .BindInput(context_->GetMapedTensor(src_input))
-        .BindInput(context_->GetMapedTensor(src_slope));
+        .BindInput(context_->GetMappedTensor(src_input))
+        .BindInput(context_->GetMappedTensor(src_slope));
(*prelu).BindOutput(out_infer[0]);
context_->SetPermuteVector(op_->impl()->OutputsTensor()[0], input_pv);
next_tensors.push_back(op_->impl()->OutputsTensor()[0]);
2 changes: 1 addition & 1 deletion src/tim/transform/ops/addn_layout_inference.h
@@ -44,7 +44,7 @@ class AddNLayoutInfer : public OpLayoutInfer {
auto addn = op_->Clone(context_->infer_graph_);

for (const auto& i_src : op_->impl()->InputsTensor()) {
-      (*addn).BindInput(context_->GetMapedTensor(i_src));
+      (*addn).BindInput(context_->GetMappedTensor(i_src));
}
auto infer_out = CreateOutputsTensor(required_pv);
(*addn).BindOutput(infer_out[0]);
2 changes: 1 addition & 1 deletion src/tim/transform/ops/arg_layout_inference.h
@@ -45,7 +45,7 @@ class ArgLayoutInfer : public OpLayoutInfer {

auto arg = op_->Clone(context_->infer_graph_);
auto infer_out = CreateOutputsTensor(input_pv);
-    (*arg).BindInput(context_->GetMapedTensor(src_input));
+    (*arg).BindInput(context_->GetMappedTensor(src_input));
(*arg).BindOutput(infer_out[0]);

context_->SetPermuteVector(op_->impl()->OutputsTensor()[0], input_pv);
4 changes: 2 additions & 2 deletions src/tim/transform/ops/batch2space_layout_inference.h
@@ -51,7 +51,7 @@ class Batch2SpaceLayoutInfer : public OpLayoutInfer {
auto final_pv = pv->Reverse()->Add(required_pv);
if (!final_pv->IsAligned()) {
auto perm_out =
-          InsertPermute(context_->GetMapedTensor(input_tensors[0]), final_pv);
+          InsertPermute(context_->GetMappedTensor(input_tensors[0]), final_pv);
context_->UpdateTensorMap(input_tensors[0], perm_out);
context_->SetPermuteVector(input_tensors[0], required_pv);
}
@@ -70,7 +70,7 @@ class Batch2SpaceLayoutInfer : public OpLayoutInfer {
context_->infer_graph_->CreateOperation<vx::ops::Batch2Space>(
block_size, crop, vx::DataLayout::WHCN);
auto out_tensor_infer = CreateOutputsTensor(required_pv);
-    (*batch2space).BindInput(context_->GetMapedTensor(input_tensors[0]));
+    (*batch2space).BindInput(context_->GetMappedTensor(input_tensors[0]));
(*batch2space).BindOutput(out_tensor_infer[0]);
context_->SetPermuteVector(op_->impl()->OutputsTensor()[0], required_pv);
// Add out tensor of src_graph into next_tensor
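The check in this file, final_pv = pv->Reverse()->Add(required_pv) followed by IsAligned(), recurs across these handlers: a transpose is inserted only when the input's current permutation composed with the required one is not the identity. A toy model of that algebra, under the assumption that Reverse() is the inverse permutation, Add() is composition, and IsAligned() tests for identity (the real IPermuteVector interface may differ):

// Toy permute-vector algebra; p[i] names the input axis that output axis i
// takes. These semantics are assumed, not confirmed by the diff.
#include <cstdint>
#include <vector>

using PermVec = std::vector<uint32_t>;

PermVec Reverse(const PermVec& p) {  // inverse: undoes the permutation p
  PermVec inv(p.size());
  for (uint32_t i = 0; i < p.size(); ++i) inv[p[i]] = i;
  return inv;
}

PermVec Add(const PermVec& a, const PermVec& b) {  // apply a, then b
  PermVec out(b.size());
  for (uint32_t i = 0; i < b.size(); ++i) out[i] = a[b[i]];
  return out;
}

bool IsAligned(const PermVec& p) {  // identity => no transpose needed
  for (uint32_t i = 0; i < p.size(); ++i)
    if (p[i] != i) return false;
  return true;
}

// Example: input already permuted by {1, 0, 2, 3}, required also {1, 0, 2, 3}:
// Add(Reverse({1, 0, 2, 3}), {1, 0, 2, 3}) == {0, 1, 2, 3}, so IsAligned is
// true and no extra Permute op needs to be inserted.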
12 changes: 6 additions & 6 deletions src/tim/transform/ops/batchnorm_layout_inference.h
@@ -57,7 +57,7 @@ class BatchNormLayoutInfer : public OpLayoutInfer {
perm_out = context_->infer_graph_->CreateTensor(src_in->GetSpec(), (const void*)dataRef.data());
input_pv = MakeShared(src_in->GetShape().size());
} else {
-        perm_out = context_->GetMapedTensor(src_in);
+        perm_out = context_->GetMappedTensor(src_in);
input_pv = context_->GetPermuteVector(src_in);
context_->SetPermuteVector(src_in, input_pv);
if (idx == 0) {
@@ -73,11 +73,11 @@

auto batchnorm = op_->Clone(context_->infer_graph_);
auto out_tensor_infer = CreateOutputsTensor(required_pv);
-    (*batchnorm).BindInput(context_->GetMapedTensor(input_tensors[0]));
-    (*batchnorm).BindInput(context_->GetMapedTensor(input_tensors[1]));
-    (*batchnorm).BindInput(context_->GetMapedTensor(input_tensors[2]));
-    (*batchnorm).BindInput(context_->GetMapedTensor(input_tensors[3]));
-    (*batchnorm).BindInput(context_->GetMapedTensor(input_tensors[4]));
+    (*batchnorm).BindInput(context_->GetMappedTensor(input_tensors[0]));
+    (*batchnorm).BindInput(context_->GetMappedTensor(input_tensors[1]));
+    (*batchnorm).BindInput(context_->GetMappedTensor(input_tensors[2]));
+    (*batchnorm).BindInput(context_->GetMappedTensor(input_tensors[3]));
+    (*batchnorm).BindInput(context_->GetMappedTensor(input_tensors[4]));

(*batchnorm).BindOutput(out_tensor_infer[0]);
context_->SetPermuteVector(op_->impl()->OutputsTensor()[0], required_pv);
2 changes: 1 addition & 1 deletion src/tim/transform/ops/bidirectional_rnn_layout_inference.h
@@ -72,7 +72,7 @@ class BidirectionalRnnLayoutInfer : public OpLayoutInfer {


for (const auto& i_src : op_->impl()->InputsTensor()) {
-      (*cloned_op).BindInput(context_->GetMapedTensor(i_src));
+      (*cloned_op).BindInput(context_->GetMappedTensor(i_src));
}


2 changes: 1 addition & 1 deletion src/tim/transform/ops/broadcast_layout_inference.h
@@ -46,7 +46,7 @@ class BroadcastLayoutInfer : public OpLayoutInfer {
auto cloned_op = op_->Clone(context_->infer_graph_);

for (const auto& i_src : op_->impl()->InputsTensor()) {
-      (*cloned_op).BindInput(context_->GetMapedTensor(i_src));
+      (*cloned_op).BindInput(context_->GetMappedTensor(i_src));
}

std::vector<std::shared_ptr<IPermuteVector>> required_pv_lst;
2 changes: 1 addition & 1 deletion src/tim/transform/ops/concat_layout_inferene.h
@@ -47,7 +47,7 @@ class ConcatLayoutInfer : public OpLayoutInfer {
auto concat = context_->infer_graph_->CreateOperation<vx::ops::Concat>(
axis, op_->impl()->InputsTensor().size());
for (const auto& i_src : op_->impl()->InputsTensor()) {
-      (*concat).BindInput(context_->GetMapedTensor(i_src));
+      (*concat).BindInput(context_->GetMappedTensor(i_src));
}
auto out_infer = CreateOutputsTensor(required_pv);
(*concat).BindOutput(out_infer[0]);
(Diffs for the remaining 37 of the 47 changed files were not loaded.)
