Refine batchmatmul mapper

Optimize code for the batchmatmul mapper because TIM-VX has fixed the broadcast param issue.

Type: Code Improvement | Documentation
Signed-off-by: Feiyue Chen <[email protected]>
chenfeiyue-cfy committed Jan 3, 2024
1 parent 9bcbd67 commit fc7ccac
Showing 1 changed file with 59 additions and 84 deletions.
op_map.cc: 143 changes (59 additions & 84 deletions)
@@ -1983,16 +1983,18 @@ struct BatchMatmul : public OpMapperBase<TfLiteBatchMatMulParams> {
         reinterpret_cast<const TfLiteBatchMatMulParams*>(node->builtin_data);
     bool adj_x = builtin->adj_x;
     bool adj_y = builtin->adj_y;
+    auto input0_type = context->tensors[node->inputs->data[0]].type;
+    auto input1_type = context->tensors[node->inputs->data[1]].type;
     if (context->tensors[node->outputs->data[0]].type == kTfLiteInt32) {
       TFLITE_LOG_PROD(TFLITE_LOG_ERROR,
                       "I32 outputs type is not supported in BatchMatmul");
       return false;
     }
-    if (
-        (context->tensors[node->inputs->data[1]].type == kTfLiteFloat32 &&
-         context->tensors[node->inputs->data[0]].type == kTfLiteInt8)) {
+    if ((input0_type == kTfLiteFloat32 && input1_type == kTfLiteInt8) ||
+        (input1_type == kTfLiteFloat32 && input0_type == kTfLiteInt8)) {
       TFLITE_LOG_PROD(TFLITE_LOG_ERROR,
-                      "F32/I8 inputs type is not supported in BatchMatmul");
+                      "Input with one being float32 and the other being int8 "
+                      "is not supported in BatchMatmul");
       return false;
     }
     if (adj_x && adj_y) {
@@ -2005,92 +2007,65 @@ struct BatchMatmul : public OpMapperBase<TfLiteBatchMatMulParams> {
     return true;
   }
   bool HandleMapOp(vx::delegate::Delegate* delegate,
                    std::vector<std::shared_ptr<tim::vx::Tensor>>& inputs,
                    std::vector<std::shared_ptr<tim::vx::Tensor>>& outputs,
                    const void* params) override {
     TFLITE_LOG(TFLITE_LOG_INFO, "Create BatchMatmul op");
     const auto builtin =
         reinterpret_cast<const TfLiteBatchMatMulParams*>(params);
     bool adj_x = builtin->adj_x;
     bool adj_y = builtin->adj_y;
 
-    std::vector<std::vector<uint32_t>> in_shape = {inputs[0]->GetShape(),
-                                                   inputs[1]->GetShape()};
-    bool broadcast_required = false;
-
-    // Need broadcast or not
-    if (in_shape[0].size() != in_shape[1].size()) {
-      broadcast_required = true;
-    } else {
-      for (int i = 2; i < in_shape[0].size(); ++i) {
-        if (in_shape[0][i] != in_shape[1][i]) {
-          broadcast_required = true;
-        }
-      }
-    }
+    auto in0_shape = inputs[0]->GetShape();
+    auto in1_shape = inputs[1]->GetShape();
+    bool broadcast_required =
+        (in0_shape.size() != in1_shape.size()) ||
+        !std::equal(in0_shape.begin() + 2, in0_shape.end(),
+                    in1_shape.begin() + 2, in1_shape.end());
 
     std::vector<std::shared_ptr<tim::vx::Tensor>> broadcast_out;
     if (broadcast_required) {
-      int out_cnt;
-      auto dim_iter0 = in_shape[0].begin();
-      auto dim_iter1 = in_shape[1].begin();
-      // Minimum 2 dimensions do not require broadcast
-      dim_iter0 += 2;
-      dim_iter1 += 2;
-      std::vector<std::vector<uint32_t>> out_shape = {
-          {in_shape[0][0], in_shape[0][1]}, {in_shape[1][0], in_shape[1][1]}};
-      while (1) {
-        if (dim_iter0 != in_shape[0].end() && dim_iter1 != in_shape[1].end()) {
-          out_shape[0].push_back(std::max(*dim_iter1, *dim_iter0));
-          out_shape[1].push_back(std::max(*dim_iter1, *dim_iter0));
-        } else {
-          if (in_shape[0].size() > in_shape[1].size()) {
-            out_shape[0].push_back(*dim_iter0);
-            out_shape[1].push_back(*dim_iter0);
-          } else {
-            out_shape[0].push_back(*dim_iter1);
-            out_shape[1].push_back(*dim_iter1);
-          }
-        }
-        if (dim_iter0 != in_shape[0].end()) dim_iter0++;
-        if (dim_iter1 != in_shape[1].end()) dim_iter1++;
-
-        if (dim_iter0 == in_shape[0].end() && dim_iter1 == in_shape[1].end()) {
-          break;
-        }
-      }
-      for (int i = 0; i < inputs.size(); ++i) {
-        if (out_shape[i] != in_shape[i]) {
-          tim::vx::TensorSpec spec = inputs[i]->GetSpec();
-          spec = spec.AsTransientSpec();
+      std::vector<std::vector<uint32_t>> out_shape = {
+          {in0_shape[0], in0_shape[1]}, {in1_shape[0], in1_shape[1]}};
+
+      for (int i = 2; i < std::max(in0_shape.size(), in1_shape.size()); ++i) {
+        uint32_t dim1 = (i < in0_shape.size()) ? in0_shape[i] : 1;
+        uint32_t dim2 = (i < in1_shape.size()) ? in1_shape[i] : 1;
+        uint32_t max_dim = std::max(dim1, dim2);
+        out_shape[0].push_back(max_dim);
+        out_shape[1].push_back(max_dim);
+      }
+
+      for (size_t i = 0; i < inputs.size(); ++i) {
+        if (out_shape[i] != inputs[i]->GetShape()) {
+          tim::vx::TensorSpec spec(inputs[i]->GetSpec().AsTransientSpec());
           broadcast_out.push_back(delegate->GetGraph()->CreateTensor(spec));
-          std::vector<int32>
-              broadcast_param;  // for Broadcast constructor parameters
-          for (auto iter = out_shape[i].begin(); iter != out_shape[i].end();
-               iter++) {
-            broadcast_param.push_back(*iter);
-          }
-          auto op_broadcast =
-              delegate->GetGraph()->CreateOperation<tim::vx::ops::Broadcast>(
-                  broadcast_param);
+
+#if defined (BROADCAST_OPVERSION) && (BROADCAST_OPVERSION == 1)
+          auto op_broadcast = delegate->GetGraph()->CreateOperation<tim::vx::ops::Broadcast>(out_shape[i]);
+#else
+          std::vector<int> broadcast_param(out_shape[i].begin(), out_shape[i].end());
+          auto op_broadcast = delegate->GetGraph()->CreateOperation<tim::vx::ops::Broadcast>(broadcast_param);
+#endif
+
           (*op_broadcast).BindInput(inputs[i]).BindOutput(broadcast_out[i]);
         } else {
           broadcast_out.push_back(inputs[i]);
         }
       }
-    }
+    } else {
+      broadcast_out = inputs;
+    }
+
     // adj_x & adj_y both true are not supported
     auto op = delegate->GetGraph()->CreateOperation<tim::vx::ops::Matmul>(
         adj_x, adj_y);
-    broadcast_required
-        ? (*op).BindInput(broadcast_out[0]).BindInput(broadcast_out[1])
-        : (*op).BindInputs(inputs);
-    (*op).BindOutputs(outputs);
+    (*op).BindInputs(broadcast_out).BindOutputs(outputs);
+
     delegate->GetOps().push_back(std::move(op));
-
     return true;
   }
 };
 
 struct Rnn : public OpMapperBase<TfLiteRNNParams> {
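For reference, a minimal sketch of the symmetric input-type check from the first hunk. The helper name is hypothetical and not part of the delegate; only the two rejected float32/int8 combinations return true:

    #include "tensorflow/lite/c/common.h"

    // Hypothetical helper mirroring the mapper's check: BatchMatmul is
    // rejected when one input is float32 and the other is int8, in
    // either argument order.
    static bool IsMixedF32I8Pair(TfLiteType input0_type, TfLiteType input1_type) {
      return (input0_type == kTfLiteFloat32 && input1_type == kTfLiteInt8) ||
             (input1_type == kTfLiteFloat32 && input0_type == kTfLiteInt8);
    }

    // IsMixedF32I8Pair(kTfLiteFloat32, kTfLiteInt8) -> true (the direction
    //                                                  the old check missed)
    // IsMixedF32I8Pair(kTfLiteInt8, kTfLiteFloat32) -> true
    // IsMixedF32I8Pair(kTfLiteInt8, kTfLiteInt8)    -> false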

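Similarly, a self-contained sketch of the batch-dimension alignment performed before Broadcast in the second hunk. It assumes TIM-VX dimension order (dims 0 and 1 are the matrix dims, dims 2 and up are batch dims), and the helper names are hypothetical:

    #include <algorithm>
    #include <cstdint>
    #include <iostream>
    #include <vector>

    // True when the batch dims (index 2 onward) differ in rank or extent.
    // Inputs are assumed to be at least rank 2, as BatchMatmul requires.
    bool BroadcastRequired(const std::vector<uint32_t>& a,
                           const std::vector<uint32_t>& b) {
      return a.size() != b.size() ||
             !std::equal(a.begin() + 2, a.end(), b.begin() + 2, b.end());
    }

    // Element-wise max over batch dims, treating a missing dim as 1;
    // like the mapper, it does not validate that a mismatched dim is 1.
    std::vector<uint32_t> AlignedBatchDims(const std::vector<uint32_t>& a,
                                           const std::vector<uint32_t>& b) {
      std::vector<uint32_t> batch;
      for (size_t i = 2; i < std::max(a.size(), b.size()); ++i) {
        uint32_t da = (i < a.size()) ? a[i] : 1;
        uint32_t db = (i < b.size()) ? b[i] : 1;
        batch.push_back(std::max(da, db));
      }
      return batch;
    }

    int main() {
      std::vector<uint32_t> in0 = {4, 3, 1, 2};  // batch dims {1, 2}
      std::vector<uint32_t> in1 = {3, 4, 5};     // batch dims {5}
      std::cout << BroadcastRequired(in0, in1) << "\n";  // prints 1
      for (uint32_t d : AlignedBatchDims(in0, in1)) std::cout << d << " ";
      std::cout << "\n";  // prints "5 2": both inputs broadcast to batch {5, 2}
      return 0;
    }

The matrix dims (index 0 and 1) stay per-input because Matmul itself handles them; only the aligned batch dims are appended to each input's Broadcast target shape.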