diff --git a/tensorflow/core/grappler/optimizers/constant_folding.cc b/tensorflow/core/grappler/optimizers/constant_folding.cc index 3b56f103092dfe..8cd1968df757f8 100644 --- a/tensorflow/core/grappler/optimizers/constant_folding.cc +++ b/tensorflow/core/grappler/optimizers/constant_folding.cc @@ -1913,28 +1913,14 @@ Status ConstantFolding::SimplifyNode(bool use_shape_info, NodeDef* node, } } - if (use_shape_info && IsPad(*node) && - properties->GetInputProperties(node->name()).size() >= 2) { - const auto& p = properties->GetInputProperties(node->name())[1]; - if (TensorShape::IsValid(p.shape()) && p.has_value()) { - Tensor paddings(p.dtype(), p.shape()); - if (!paddings.FromProto(p.value())) { - return errors::InvalidArgument("Cannot parse tensor from proto: ", - p.value().DebugString()); - } - // The node is replaceable iff all values in paddings are 0. - bool replaceable = true; - // The operation requires it to be int32 value so we don't check for - // 1nt64. - const auto flatten = paddings.flat(); - for (int j = 0; replaceable && j < flatten.size(); ++j) { - replaceable &= flatten(j) == 0; - } - if (replaceable) { - ReplaceOperationWithIdentity(0, *properties, node, optimized_graph); - return Status::OK(); - } - } + bool simplify_pad_successful = false; + Status simplify_pad_status = + SimplifyPad(*properties, use_shape_info, optimized_graph, node, + &simplify_pad_successful); + if (!simplify_pad_status.ok()) { + return simplify_pad_status; + } else if (simplify_pad_successful) { + return Status::OK(); } if (SimplifySqueeze(*properties, use_shape_info, optimized_graph, node)) { @@ -2010,6 +1996,38 @@ Status ConstantFolding::SimplifyNode(bool use_shape_info, NodeDef* node, return Status::OK(); } +Status ConstantFolding::SimplifyPad(const GraphProperties& properties, + bool use_shape_info, + GraphDef* optimized_graph, NodeDef* node, + bool* success) { + if (use_shape_info && IsPad(*node) && + properties.GetInputProperties(node->name()).size() >= 2) { + const auto& p = properties.GetInputProperties(node->name())[1]; + if (TensorShape::IsValid(p.shape()) && p.has_value()) { + Tensor paddings(p.dtype(), p.shape()); + if (!paddings.FromProto(p.value())) { + return errors::InvalidArgument("Cannot parse tensor from proto: ", + p.value().DebugString()); + } + // The node is replaceable iff all values in paddings are 0. + bool replaceable = true; + // The operation requires it to be int32 value so we don't check for + // 1nt64. + const auto flatten = paddings.flat(); + for (int j = 0; replaceable && j < flatten.size(); ++j) { + replaceable &= flatten(j) == 0; + } + if (replaceable) { + ReplaceOperationWithIdentity(0, properties, node, optimized_graph); + *success = true; + return Status::OK(); + } + } + } + *success = false; + return Status::OK(); +} + bool ConstantFolding::SimplifySqueeze(const GraphProperties& properties, bool use_shape_info, GraphDef* optimized_graph, diff --git a/tensorflow/core/grappler/optimizers/constant_folding.h b/tensorflow/core/grappler/optimizers/constant_folding.h index 55ad686bc58ccb..fa9249f50c1012 100644 --- a/tensorflow/core/grappler/optimizers/constant_folding.h +++ b/tensorflow/core/grappler/optimizers/constant_folding.h @@ -174,6 +174,10 @@ class ConstantFolding : public GraphOptimizer { bool SimplifySqueeze(const GraphProperties& properties, bool use_shape_info, GraphDef* optimized_graph, NodeDef* node); + // Simplifies a Pad operation to an Identity operation if applicable. + Status SimplifyPad(const GraphProperties& properties, bool use_shape_info, + GraphDef* optimized_graph, NodeDef* node, bool* success); + // Points to an externally provided device or to owned_device_; RewriterConfig::Toggle opt_level_; DeviceBase* cpu_device_;