From fc7ccaceed91996a3cd80bc6245a4ab64118eb0a Mon Sep 17 00:00:00 2001 From: Feiyue Chen Date: Wed, 3 Jan 2024 08:07:56 +0000 Subject: [PATCH] Refine batchmatmul mapper optimize code for batchmatmul mapper because tim-vx have fixed broadcast param issue Type: Code Improvement | Documentation Signed-off-by: Feiyue Chen --- op_map.cc | 143 ++++++++++++++++++++++-------------------------------- 1 file changed, 59 insertions(+), 84 deletions(-) diff --git a/op_map.cc b/op_map.cc index 34aef02..4cc6514 100644 --- a/op_map.cc +++ b/op_map.cc @@ -1983,16 +1983,18 @@ struct BatchMatmul : public OpMapperBase { reinterpret_cast(node->builtin_data); bool adj_x = builtin->adj_x; bool adj_y = builtin->adj_y; + auto input0_type = context->tensors[node->inputs->data[0]].type; + auto input1_type = context->tensors[node->inputs->data[1]].type; if (context->tensors[node->outputs->data[0]].type == kTfLiteInt32) { TFLITE_LOG_PROD(TFLITE_LOG_ERROR, "I32 outputs type is not supported in BatchMatmul"); return false; } - if ( - (context->tensors[node->inputs->data[1]].type == kTfLiteFloat32 && - context->tensors[node->inputs->data[0]].type == kTfLiteInt8)) { + if ((input0_type == kTfLiteFloat32 && input1_type == kTfLiteInt8) || + (input1_type == kTfLiteFloat32 && input0_type == kTfLiteInt8)) { TFLITE_LOG_PROD(TFLITE_LOG_ERROR, - "F32/I8 inputs type is not supported in BatchMatmul"); + "Input with one being float32 and the other being int8 " + "is not supported in BatchMatmul"); return false; } if (adj_x && adj_y) { @@ -2005,92 +2007,65 @@ struct BatchMatmul : public OpMapperBase { return true; } bool HandleMapOp(vx::delegate::Delegate* delegate, - std::vector>& inputs, - std::vector>& outputs, - const void* params) override { - TFLITE_LOG(TFLITE_LOG_INFO, "Create BatchMatmul op"); - const auto builtin = - reinterpret_cast(params); - bool adj_x = builtin->adj_x; - bool adj_y = builtin->adj_y; - - std::vector> in_shape = {inputs[0]->GetShape(), - inputs[1]->GetShape()}; - bool broadcast_required = false; - - // Need broadcast or not - if (in_shape[0].size() != in_shape[1].size()) { - broadcast_required = true; - } else { - for (int i = 2; i < in_shape[0].size(); ++i) { - if (in_shape[0][i] != in_shape[1][i]) { - broadcast_required = true; - } + std::vector>& inputs, + std::vector>& outputs, + const void* params) override { + TFLITE_LOG(TFLITE_LOG_INFO, "Create BatchMatmul op"); + const auto builtin = + reinterpret_cast(params); + bool adj_x = builtin->adj_x; + bool adj_y = builtin->adj_y; + + auto in0_shape = inputs[0]->GetShape(); + auto in1_shape = inputs[1]->GetShape(); + bool broadcast_required = + (in0_shape.size() != in1_shape.size()) || + !std::equal(in0_shape.begin() + 2, in0_shape.end(), + in1_shape.begin() + 2, in1_shape.end()); + + std::vector> broadcast_out; + if (broadcast_required) { + std::vector> out_shape = { + {in0_shape[0], in0_shape[1]}, {in1_shape[0], in1_shape[1]}}; + + for (int i = 2; i < std::max(in0_shape.size(), in1_shape.size()); ++i) { + uint32_t dim1 = (i < in0_shape.size()) ? in0_shape[i] : 1; + uint32_t dim2 = (i < in1_shape.size()) ? in1_shape[i] : 1; + uint32_t max_dim = std::max(dim1, dim2); + out_shape[0].push_back(max_dim); + out_shape[1].push_back(max_dim); + } + + for (size_t i = 0; i < inputs.size(); ++i) { + if (out_shape[i] != inputs[i]->GetShape()) { + tim::vx::TensorSpec spec(inputs[i]->GetSpec().AsTransientSpec()); + broadcast_out.push_back(delegate->GetGraph()->CreateTensor(spec)); + + #if defined (BROADCAST_OPVERSION) && (BROADCAST_OPVERSION == 1) + auto op_broadcast = delegate->GetGraph()->CreateOperation(out_shape[i]); + #else + std::vector broadcast_param (out_shape[i].begin(),out_shape[i].end()); + auto op_broadcast = delegate->GetGraph()->CreateOperation(broadcast_param); + #endif + + (*op_broadcast).BindInput(inputs[i]).BindOutput(broadcast_out[i]); + } else { + broadcast_out.push_back(inputs[i]); } } + } else { + broadcast_out = inputs; + } - std::vector> broadcast_out; - if (broadcast_required) { - int out_cnt; - auto dim_iter0 = in_shape[0].begin(); - auto dim_iter1 = in_shape[1].begin(); - // Minimum 2 dimensions do not require broadcast - dim_iter0 += 2; - dim_iter1 += 2; - std::vector> out_shape = { - {in_shape[0][0], in_shape[0][1]}, {in_shape[1][0], in_shape[1][1]}}; - while (1) { - if (dim_iter0 != in_shape[0].end() && dim_iter1 != in_shape[1].end()) { - out_shape[0].push_back(std::max(*dim_iter1, *dim_iter0)); - out_shape[1].push_back(std::max(*dim_iter1, *dim_iter0)); - } else { - if (in_shape[0].size() > in_shape[1].size()) { - out_shape[0].push_back(*dim_iter0); - out_shape[1].push_back(*dim_iter0); - } else { - out_shape[0].push_back(*dim_iter1); - out_shape[1].push_back(*dim_iter1); - } - } - if (dim_iter0 != in_shape[0].end()) dim_iter0++; - if (dim_iter1 != in_shape[1].end()) dim_iter1++; + // adj_x & adj_y both true are not supported + auto op = delegate->GetGraph()->CreateOperation(adj_x, adj_y); + (*op).BindInputs(broadcast_out).BindOutputs(outputs); - if (dim_iter0 == in_shape[0].end() && dim_iter1 == in_shape[1].end()) { - break; - } - } - for (int i = 0; i < inputs.size(); ++i) { - if (out_shape[i] != in_shape[i]) { - tim::vx::TensorSpec spec = inputs[i]->GetSpec(); - spec = spec.AsTransientSpec(); - broadcast_out.push_back(delegate->GetGraph()->CreateTensor(spec)); - std::vector - broadcast_param; // for Broadcast constructor parameters - for (auto iter = out_shape[i].begin(); iter != out_shape[i].end(); - iter++) { - broadcast_param.push_back(*iter); - } - auto op_broadcast = - delegate->GetGraph()->CreateOperation( - broadcast_param); - (*op_broadcast).BindInput(inputs[i]).BindOutput(broadcast_out[i]); - } else { - broadcast_out.push_back(inputs[i]); - } - } - } - // adj_x & adj_y both true are not supported - auto op = delegate->GetGraph()->CreateOperation( - adj_x, adj_y); - broadcast_required - ? (*op).BindInput(broadcast_out[0]).BindInput(broadcast_out[1]) - : (*op).BindInputs(inputs); - (*op).BindOutputs(outputs); + delegate->GetOps().push_back(std::move(op)); - delegate->GetOps().push_back(std::move(op)); + return true; +} - return true; - } }; struct Rnn : public OpMapperBase {