Skip to content

Commit

Permalink
Unique Init, Prepare, Eval functions Kernels A-M (#2344)
Browse files Browse the repository at this point in the history
Refactor the Init, Prepare, and Eval functions to have unique names for kernels whose names start with the letters A-M.

BUG=[b/313963581](https://b.corp.google.com/issues/313963581)
  • Loading branch information
turbotoribio authored Dec 8, 2023
1 parent a7ff71a commit 5f5fcc8
Show file tree
Hide file tree
Showing 26 changed files with 109 additions and 92 deletions.
10 changes: 6 additions & 4 deletions tensorflow/lite/micro/kernels/batch_matmul.cc
Original file line number Diff line number Diff line change
Expand Up @@ -279,7 +279,8 @@ RuntimeShape SwapRowColumnDims(const RuntimeShape& shape) {
return swapped_shape;
}

void* Init(TfLiteContext* context, const char* buffer, size_t length) {
void* BatchMatMulInit(TfLiteContext* context, const char* buffer,
size_t length) {
// This is a builtin op, so we don't use the contents in 'buffer', if any.
// Instead, we allocate a new object to carry information from Prepare() to
// Eval().
Expand All @@ -288,7 +289,7 @@ void* Init(TfLiteContext* context, const char* buffer, size_t length) {
return micro_context->AllocatePersistentBuffer(sizeof(OpData));
}

TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TfLiteStatus BatchMatMulPrepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);

Expand Down Expand Up @@ -463,7 +464,7 @@ TfLiteStatus EvalInt16(TfLiteContext* context, const OpData& data,
// RHS <..., C, B> X LHS <..., B, A>
// where output is a C X A column-oriented, which is equivalent to
// A X C row-oriented.
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
TfLiteStatus BatchMatMulEval(TfLiteContext* context, TfLiteNode* node) {
EvalOpContext op_context(context, node);
OpData* op_data = op_context.op_data;
const TfLiteEvalTensor* lhs = op_context.lhs;
Expand Down Expand Up @@ -550,7 +551,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
} // namespace

TFLMRegistration Register_BATCH_MATMUL() {
return tflite::micro::RegisterOp(Init, Prepare, Eval);
return tflite::micro::RegisterOp(BatchMatMulInit, BatchMatMulPrepare,
BatchMatMulEval);
}

} // namespace tflite
7 changes: 4 additions & 3 deletions tensorflow/lite/micro/kernels/batch_to_space_nd.cc
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ constexpr int kOutputTensor = 0;
const int kInputOutputMinDimensionNum = 3;
const int kInputOutputMaxDimensionNum = 4;

TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TfLiteStatus BatchToSpaceNDPrepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 3);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);

Expand All @@ -62,7 +62,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
return kTfLiteOk;
}

TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
TfLiteStatus BatchToSpaceNDEval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteEvalTensor* input =
tflite::micro::GetEvalInput(context, node, kInputTensor);
const TfLiteEvalTensor* block_shape =
Expand Down Expand Up @@ -106,7 +106,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
} // namespace.

TFLMRegistration Register_BATCH_TO_SPACE_ND() {
return tflite::micro::RegisterOp(nullptr, Prepare, Eval);
return tflite::micro::RegisterOp(nullptr, BatchToSpaceNDPrepare,
BatchToSpaceNDEval);
}

} // namespace tflite
8 changes: 4 additions & 4 deletions tensorflow/lite/micro/kernels/call_once.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,12 @@ struct OpData {
bool has_run;
};

void* Init(TfLiteContext* context, const char* buffer, size_t length) {
void* CallOnceInit(TfLiteContext* context, const char* buffer, size_t length) {
TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
return context->AllocatePersistentBuffer(context, sizeof(OpData));
}

TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TfLiteStatus CallOncePrepare(TfLiteContext* context, TfLiteNode* node) {
OpData* op_data = reinterpret_cast<OpData*>(node->user_data);
const auto* params =
reinterpret_cast<const TfLiteCallOnceParams*>(node->builtin_data);
Expand All @@ -60,7 +60,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
return kTfLiteOk;
}

TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
TfLiteStatus CallOnceEval(TfLiteContext* context, TfLiteNode* node) {
OpData* op_data = reinterpret_cast<OpData*>(node->user_data);

// Call once only runs one time then is a no-op for every subsequent call.
Expand All @@ -82,7 +82,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
} // namespace.

TFLMRegistration Register_CALL_ONCE() {
return tflite::micro::RegisterOp(Init, Prepare, Eval);
return tflite::micro::RegisterOp(CallOnceInit, CallOncePrepare, CallOnceEval);
}

} // namespace tflite
6 changes: 3 additions & 3 deletions tensorflow/lite/micro/kernels/cast.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ namespace {
constexpr int kInputTensor = 0;
constexpr int kOutputTensor = 0;

TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TfLiteStatus CastPrepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);

Expand Down Expand Up @@ -77,7 +77,7 @@ TfLiteStatus copyToTensor(TfLiteContext* context, const FromT* in,
return kTfLiteOk;
}

TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
TfLiteStatus CastEval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteEvalTensor* input =
tflite::micro::GetEvalInput(context, node, kInputTensor);
TfLiteEvalTensor* output =
Expand Down Expand Up @@ -111,7 +111,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
} // namespace

TFLMRegistration Register_CAST() {
return tflite::micro::RegisterOp(nullptr, Prepare, Eval);
return tflite::micro::RegisterOp(nullptr, CastPrepare, CastEval);
}

} // namespace tflite
6 changes: 3 additions & 3 deletions tensorflow/lite/micro/kernels/ceil.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ namespace {
constexpr int kInputTensor = 0;
constexpr int kOutputTensor = 0;

TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TfLiteStatus CeilPrepare(TfLiteContext* context, TfLiteNode* node) {
MicroContext* micro_context = GetMicroContext(context);

TfLiteTensor* input =
Expand All @@ -50,7 +50,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
return kTfLiteOk;
}

TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
TfLiteStatus CeilEval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteEvalTensor* input =
tflite::micro::GetEvalInput(context, node, kInputTensor);
TfLiteEvalTensor* output =
Expand All @@ -67,7 +67,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
} // namespace

TFLMRegistration Register_CEIL() {
return tflite::micro::RegisterOp(nullptr, Prepare, Eval);
return tflite::micro::RegisterOp(nullptr, CeilPrepare, CeilEval);
}

} // namespace tflite
14 changes: 7 additions & 7 deletions tensorflow/lite/micro/kernels/comparisons.cc
Original file line number Diff line number Diff line change
Expand Up @@ -533,7 +533,7 @@ void* Init(TfLiteContext* context, const char* buffer, size_t length) {
return context->AllocatePersistentBuffer(context, sizeof(OpData));
}

TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TfLiteStatus ComparisonsPrepare(TfLiteContext* context, TfLiteNode* node) {
TFLITE_DCHECK(node->user_data != nullptr);
OpData* data = static_cast<OpData*>(node->user_data);

Expand Down Expand Up @@ -580,27 +580,27 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
} // namespace

TFLMRegistration Register_EQUAL() {
return tflite::micro::RegisterOp(Init, Prepare, EqualEval);
return tflite::micro::RegisterOp(Init, ComparisonsPrepare, EqualEval);
}

TFLMRegistration Register_NOT_EQUAL() {
return tflite::micro::RegisterOp(Init, Prepare, NotEqualEval);
return tflite::micro::RegisterOp(Init, ComparisonsPrepare, NotEqualEval);
}

TFLMRegistration Register_GREATER() {
return tflite::micro::RegisterOp(Init, Prepare, GreaterEval);
return tflite::micro::RegisterOp(Init, ComparisonsPrepare, GreaterEval);
}

TFLMRegistration Register_GREATER_EQUAL() {
return tflite::micro::RegisterOp(Init, Prepare, GreaterEqualEval);
return tflite::micro::RegisterOp(Init, ComparisonsPrepare, GreaterEqualEval);
}

TFLMRegistration Register_LESS() {
return tflite::micro::RegisterOp(Init, Prepare, LessEval);
return tflite::micro::RegisterOp(Init, ComparisonsPrepare, LessEval);
}

TFLMRegistration Register_LESS_EQUAL() {
return tflite::micro::RegisterOp(Init, Prepare, LessEqualEval);
return tflite::micro::RegisterOp(Init, ComparisonsPrepare, LessEqualEval);
}

} // namespace tflite
10 changes: 6 additions & 4 deletions tensorflow/lite/micro/kernels/concatenation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -103,12 +103,13 @@ void EvalUnquantized(TfLiteContext* context, TfLiteNode* node) {
tflite::micro::GetTensorData<data_type>(output));
}

void* Init(TfLiteContext* context, const char* buffer, size_t length) {
void* ConcatenationInit(TfLiteContext* context, const char* buffer,
size_t length) {
TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
return context->AllocatePersistentBuffer(context, sizeof(OpData));
}

TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TfLiteStatus ConcatenationPrepare(TfLiteContext* context, TfLiteNode* node) {
// This function only checks the types. Additional shape validations are
// performed in the reference implementation called during Eval().
const TfLiteConcatenationParams* params =
Expand Down Expand Up @@ -214,7 +215,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
return kTfLiteOk;
}

TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
TfLiteStatus ConcatenationEval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteEvalTensor* output_tensor =
tflite::micro::GetEvalOutput(context, node, kOutputTensor);
TF_LITE_ENSURE(context, output_tensor != nullptr);
Expand Down Expand Up @@ -252,7 +253,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
} // namespace

TFLMRegistration Register_CONCATENATION() {
return tflite::micro::RegisterOp(Init, Prepare, Eval);
return tflite::micro::RegisterOp(ConcatenationInit, ConcatenationPrepare,
ConcatenationEval);
}

} // namespace tflite
4 changes: 2 additions & 2 deletions tensorflow/lite/micro/kernels/conv.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ limitations under the License.
namespace tflite {
namespace {

TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
TfLiteStatus ConvEval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteEvalTensor* input =
tflite::micro::GetEvalInput(context, node, kConvInputTensor);
const TfLiteEvalTensor* filter =
Expand Down Expand Up @@ -144,7 +144,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
} // namespace

TFLMRegistration Register_CONV_2D() {
return tflite::micro::RegisterOp(ConvInit, ConvPrepare, Eval);
return tflite::micro::RegisterOp(ConvInit, ConvPrepare, ConvEval);
}

} // namespace tflite
6 changes: 3 additions & 3 deletions tensorflow/lite/micro/kernels/cumsum.cc
Original file line number Diff line number Diff line change
Expand Up @@ -104,11 +104,11 @@ TfLiteStatus CalculateOpData(TfLiteContext* context, TfLiteNode* node) {
return kTfLiteOk;
}

TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TfLiteStatus CumSumPrepare(TfLiteContext* context, TfLiteNode* node) {
return CalculateOpData(context, node);
}

TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
TfLiteStatus CumSumEval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteEvalTensor* input =
tflite::micro::GetEvalInput(context, node, kInputTensor);
const TfLiteEvalTensor* axis_tensor =
Expand Down Expand Up @@ -169,7 +169,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
} // namespace

TFLMRegistration Register_CUMSUM() {
return tflite::micro::RegisterOp(nullptr, Prepare, Eval);
return tflite::micro::RegisterOp(nullptr, CumSumPrepare, CumSumEval);
}

} // namespace tflite
7 changes: 4 additions & 3 deletions tensorflow/lite/micro/kernels/depth_to_space.cc
Original file line number Diff line number Diff line change
Expand Up @@ -93,11 +93,11 @@ TfLiteStatus CalculateOpData(TfLiteContext* context, TfLiteNode* node) {
return kTfLiteOk;
}

TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TfLiteStatus DepthToSpacePrepare(TfLiteContext* context, TfLiteNode* node) {
return CalculateOpData(context, node);
}

TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
TfLiteStatus DepthToSpaceEval(TfLiteContext* context, TfLiteNode* node) {
auto* params =
reinterpret_cast<TfLiteDepthToSpaceParams*>(node->builtin_data);

Expand Down Expand Up @@ -136,7 +136,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
} // namespace

TFLMRegistration Register_DEPTH_TO_SPACE() {
return tflite::micro::RegisterOp(nullptr, Prepare, Eval);
return tflite::micro::RegisterOp(nullptr, DepthToSpacePrepare,
DepthToSpaceEval);
}

} // namespace tflite
8 changes: 5 additions & 3 deletions tensorflow/lite/micro/kernels/depthwise_conv.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,13 @@ limitations under the License.
namespace tflite {
namespace {

void* Init(TfLiteContext* context, const char* buffer, size_t length) {
void* DepthwiseConvInit(TfLiteContext* context, const char* buffer,
size_t length) {
TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
return context->AllocatePersistentBuffer(context, sizeof(OpDataConv));
}

TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
TfLiteStatus DepthwiseConvEval(TfLiteContext* context, TfLiteNode* node) {
TFLITE_DCHECK(node->user_data != nullptr);
TFLITE_DCHECK(node->builtin_data != nullptr);

Expand Down Expand Up @@ -143,7 +144,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
} // namespace

TFLMRegistration Register_DEPTHWISE_CONV_2D() {
return tflite::micro::RegisterOp(Init, DepthwiseConvPrepare, Eval);
return tflite::micro::RegisterOp(DepthwiseConvInit, DepthwiseConvPrepare,
DepthwiseConvEval);
}

} // namespace tflite
13 changes: 9 additions & 4 deletions tensorflow/lite/micro/kernels/detection_postprocess.cc
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,8 @@ struct OpData {
TfLiteQuantizationParams input_anchors;
};

void* Init(TfLiteContext* context, const char* buffer, size_t length) {
void* DetectionPostProcessInit(TfLiteContext* context, const char* buffer,
size_t length) {
TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
OpData* op_data = nullptr;

Expand Down Expand Up @@ -149,7 +150,8 @@ void* Init(TfLiteContext* context, const char* buffer, size_t length) {
return op_data;
}

TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TfLiteStatus DetectionPostProcessPrepare(TfLiteContext* context,
TfLiteNode* node) {
auto* op_data = static_cast<OpData*>(node->user_data);

MicroContext* micro_context = GetMicroContext(context);
Expand Down Expand Up @@ -774,7 +776,8 @@ TfLiteStatus NonMaxSuppressionMultiClass(TfLiteContext* context,
return kTfLiteOk;
}

TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
TfLiteStatus DetectionPostProcessEval(TfLiteContext* context,
TfLiteNode* node) {
TF_LITE_ENSURE(context, (kBatchSize == 1));
auto* op_data = static_cast<OpData*>(node->user_data);

Expand All @@ -800,7 +803,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
} // namespace

TFLMRegistration* Register_DETECTION_POSTPROCESS() {
static TFLMRegistration r = tflite::micro::RegisterOp(Init, Prepare, Eval);
static TFLMRegistration r = tflite::micro::RegisterOp(
DetectionPostProcessInit, DetectionPostProcessPrepare,
DetectionPostProcessEval);
return &r;
}

Expand Down
8 changes: 4 additions & 4 deletions tensorflow/lite/micro/kernels/div.cc
Original file line number Diff line number Diff line change
Expand Up @@ -65,12 +65,12 @@ TfLiteStatus CalculateOpDataDiv(TfLiteContext* context, TfLiteTensor* input1,
return kTfLiteOk;
}

void* Init(TfLiteContext* context, const char* buffer, size_t length) {
void* DivInit(TfLiteContext* context, const char* buffer, size_t length) {
TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
return context->AllocatePersistentBuffer(context, sizeof(OpDataDiv));
}

TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TfLiteStatus DivPrepare(TfLiteContext* context, TfLiteNode* node) {
TFLITE_DCHECK(node->user_data != nullptr);
TFLITE_DCHECK(node->builtin_data != nullptr);

Expand Down Expand Up @@ -179,7 +179,7 @@ TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node,
return kTfLiteOk;
}

TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
TfLiteStatus DivEval(TfLiteContext* context, TfLiteNode* node) {
TFLITE_DCHECK(node->builtin_data != nullptr);
auto* params = static_cast<TfLiteDivParams*>(node->builtin_data);
TFLITE_DCHECK(node->user_data != nullptr);
Expand Down Expand Up @@ -213,7 +213,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
} // namespace

TFLMRegistration Register_DIV() {
return tflite::micro::RegisterOp(Init, Prepare, Eval);
return tflite::micro::RegisterOp(DivInit, DivPrepare, DivEval);
}

} // namespace tflite
Loading

0 comments on commit 5f5fcc8

Please sign in to comment.