From cc74c44dd0b7fb09b77f807eec9c1d91b916ad79 Mon Sep 17 00:00:00 2001 From: Thibaut Goetghebuer-Planchon Date: Fri, 7 Jan 2022 10:32:35 +0000 Subject: [PATCH] Restrict inputs and outputs scaling/zero-point to be the same for the EXPAND_DIMS op --- tensorflow/lite/kernels/expand_dims.cc | 10 ++++++++++ tensorflow/lite/tools/optimize/operator_property.cc | 1 + 2 files changed, 11 insertions(+) diff --git a/tensorflow/lite/kernels/expand_dims.cc b/tensorflow/lite/kernels/expand_dims.cc index c8d0270551c192..c6ca42936b7de6 100644 --- a/tensorflow/lite/kernels/expand_dims.cc +++ b/tensorflow/lite/kernels/expand_dims.cc @@ -73,13 +73,22 @@ TfLiteStatus GetAxisValueFromTensor(TfLiteContext* context, TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_EQ(context, NumInputs(node), 2); TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); + const TfLiteTensor* input; TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInput, &input)); const TfLiteTensor* axis; TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kAxis, &axis)); TfLiteTensor* output; TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output)); + output->type = input->type; + TF_LITE_ENSURE_EQ(context, input->params.scale, output->params.scale); + TF_LITE_ENSURE_EQ(context, input->params.zero_point, + output->params.zero_point); + if (input->type == kTfLiteInt16) { + TF_LITE_ENSURE_EQ(context, input->params.zero_point, 0); + } + if (IsConstantTensor(axis)) { int axis_value; TF_LITE_ENSURE_OK(context, @@ -87,6 +96,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { return ExpandTensorDim(context, *input, axis_value, output); } SetTensorToDynamic(output); + return kTfLiteOk; } diff --git a/tensorflow/lite/tools/optimize/operator_property.cc b/tensorflow/lite/tools/optimize/operator_property.cc index a40f38fcfc3907..e262139f69700b 100644 --- a/tensorflow/lite/tools/optimize/operator_property.cc +++ b/tensorflow/lite/tools/optimize/operator_property.cc @@ -194,6 +194,7 @@ OperatorProperty GetOperatorProperty(OpVariant op_variant) { // hence does not need to be quantized. property.inputs = {{0, {}}}; property.outputs = {{0, {}}}; + property.restrict_same_input_output_scale = true; property.version = 1; break; case BuiltinOperator_FILL: {