Skip to content

Commit

Permalink
Merge pull request tensorflow#54078 from Tessil:toupstream/restrict_s…
Browse files Browse the repository at this point in the history
…ame_scale_zero_point_expand_dims

PiperOrigin-RevId: 448906718
  • Loading branch information
tensorflower-gardener committed May 16, 2022
2 parents ca85166 + cc74c44 commit 23c02b2
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 0 deletions.
10 changes: 10 additions & 0 deletions tensorflow/lite/kernels/expand_dims.cc
Original file line number Diff line number Diff line change
Expand Up @@ -73,20 +73,30 @@ TfLiteStatus GetAxisValueFromTensor(TfLiteContext* context,
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);

const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInput, &input));
const TfLiteTensor* axis;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kAxis, &axis));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));

output->type = input->type;
TF_LITE_ENSURE_EQ(context, input->params.scale, output->params.scale);
TF_LITE_ENSURE_EQ(context, input->params.zero_point,
output->params.zero_point);
if (input->type == kTfLiteInt16) {
TF_LITE_ENSURE_EQ(context, input->params.zero_point, 0);
}

if (IsConstantTensor(axis)) {
int axis_value;
TF_LITE_ENSURE_OK(context,
GetAxisValueFromTensor(context, *axis, &axis_value));
return ExpandTensorDim(context, *input, axis_value, output);
}
SetTensorToDynamic(output);

return kTfLiteOk;
}

Expand Down
1 change: 1 addition & 0 deletions tensorflow/lite/tools/optimize/operator_property.cc
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,7 @@ OperatorProperty GetOperatorProperty(OpVariant op_variant) {
// hence does not need to be quantized.
property.inputs = {{0, {}}};
property.outputs = {{0, {}}};
property.restrict_same_input_output_scale = true;
property.version = 1;
break;
case BuiltinOperator_FILL: {
Expand Down

0 comments on commit 23c02b2

Please sign in to comment.