Skip to content

Commit

Permalink
Extracts the 'simplify pad node' optimization into its own method.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 197989813
  • Loading branch information
tensorflower-gardener committed May 25, 2018
1 parent 6b4eeb6 commit 2b99d9c
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 22 deletions.
62 changes: 40 additions & 22 deletions tensorflow/core/grappler/optimizers/constant_folding.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1913,28 +1913,14 @@ Status ConstantFolding::SimplifyNode(bool use_shape_info, NodeDef* node,
}
}

if (use_shape_info && IsPad(*node) &&
properties->GetInputProperties(node->name()).size() >= 2) {
const auto& p = properties->GetInputProperties(node->name())[1];
if (TensorShape::IsValid(p.shape()) && p.has_value()) {
Tensor paddings(p.dtype(), p.shape());
if (!paddings.FromProto(p.value())) {
return errors::InvalidArgument("Cannot parse tensor from proto: ",
p.value().DebugString());
}
// The node is replaceable iff all values in paddings are 0.
bool replaceable = true;
// The operation requires it to be int32 value so we don't check for
// 1nt64.
const auto flatten = paddings.flat<int32>();
for (int j = 0; replaceable && j < flatten.size(); ++j) {
replaceable &= flatten(j) == 0;
}
if (replaceable) {
ReplaceOperationWithIdentity(0, *properties, node, optimized_graph);
return Status::OK();
}
}
bool simplify_pad_successful = false;
Status simplify_pad_status =
SimplifyPad(*properties, use_shape_info, optimized_graph, node,
&simplify_pad_successful);
if (!simplify_pad_status.ok()) {
return simplify_pad_status;
} else if (simplify_pad_successful) {
return Status::OK();
}

if (SimplifySqueeze(*properties, use_shape_info, optimized_graph, node)) {
Expand Down Expand Up @@ -2010,6 +1996,38 @@ Status ConstantFolding::SimplifyNode(bool use_shape_info, NodeDef* node,
return Status::OK();
}

Status ConstantFolding::SimplifyPad(const GraphProperties& properties,
bool use_shape_info,
GraphDef* optimized_graph, NodeDef* node,
bool* success) {
if (use_shape_info && IsPad(*node) &&
properties.GetInputProperties(node->name()).size() >= 2) {
const auto& p = properties.GetInputProperties(node->name())[1];
if (TensorShape::IsValid(p.shape()) && p.has_value()) {
Tensor paddings(p.dtype(), p.shape());
if (!paddings.FromProto(p.value())) {
return errors::InvalidArgument("Cannot parse tensor from proto: ",
p.value().DebugString());
}
// The node is replaceable iff all values in paddings are 0.
bool replaceable = true;
// The operation requires it to be int32 value so we don't check for
// 1nt64.
const auto flatten = paddings.flat<int32>();
for (int j = 0; replaceable && j < flatten.size(); ++j) {
replaceable &= flatten(j) == 0;
}
if (replaceable) {
ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
*success = true;
return Status::OK();
}
}
}
*success = false;
return Status::OK();
}

bool ConstantFolding::SimplifySqueeze(const GraphProperties& properties,
bool use_shape_info,
GraphDef* optimized_graph,
Expand Down
4 changes: 4 additions & 0 deletions tensorflow/core/grappler/optimizers/constant_folding.h
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,10 @@ class ConstantFolding : public GraphOptimizer {
bool SimplifySqueeze(const GraphProperties& properties, bool use_shape_info,
GraphDef* optimized_graph, NodeDef* node);

// Simplifies a Pad operation to an Identity operation if applicable.
Status SimplifyPad(const GraphProperties& properties, bool use_shape_info,
GraphDef* optimized_graph, NodeDef* node, bool* success);

// Points to an externally provided device or to owned_device_;
RewriterConfig::Toggle opt_level_;
DeviceBase* cpu_device_;
Expand Down

0 comments on commit 2b99d9c

Please sign in to comment.