Skip to content

Commit

Permalink
[tfg] Refactor OpCatHelper to cache operation names in TFGraphDialect
Browse files Browse the repository at this point in the history
Operation names are cached in the TFGraphDialect instance (as StringAttr)
so that checking the op name of unregistered operations is faster.

The long list of check functions, name declarations, and name initializers
are placed in an include file tf_op_names.inc.

PiperOrigin-RevId: 448626406
  • Loading branch information
tensorflower-gardener committed May 14, 2022
1 parent f6639a8 commit 230e240
Show file tree
Hide file tree
Showing 13 changed files with 1,998 additions and 1,640 deletions.
2 changes: 2 additions & 0 deletions tensorflow/core/ir/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,8 @@ cc_library(
srcs = [
"interfaces.cc",
"ops.cc",
"tf_op_names.cc",
"tf_op_names.inc",
"tf_op_wrapper.cc",
"utility.cc",
],
Expand Down
1 change: 1 addition & 0 deletions tensorflow/core/ir/dialect.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ using mlir::tf_type::VariantType; // NOLINT
using mlir::tf_type::VersionAttr; // NOLINT

class TFGraphOpAsmInterface;
class TFOp;
} // namespace tfg
} // namespace mlir

Expand Down
8 changes: 8 additions & 0 deletions tensorflow/core/ir/dialect.td
Original file line number Diff line number Diff line change
Expand Up @@ -133,10 +133,18 @@ def TFGraphDialect : Dialect {
llvm::unique_function<void(Operation *, OpAsmPrinter &)>
getOperationPrinter(Operation *op) const override;

// Functions for checking operation categories.
#define GET_OP_CATEGORIES
#include "tensorflow/core/ir/tf_op_names.inc"

private:
// Fallback implementation of OpAsmOpInterface.
TFGraphOpAsmInterface *fallbackOpAsmInterface_ = nullptr;

// Cached TensorFlow operation names.
#define GET_OP_NAME_DECLS
#include "tensorflow/core/ir/tf_op_names.inc"

// Cached identifier for efficiency purpose.
StringAttr assigned_device_key_;
StringAttr device_key_;
Expand Down
4 changes: 4 additions & 0 deletions tensorflow/core/ir/ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,10 @@ void TFGraphDialect::initialize() {
// Create the fallback OpAsmOpInterface instance.
fallbackOpAsmInterface_ = new TFGraphOpAsmInterface;

// Initialized the cached operation names.
#define GET_OP_NAME_DEFS
#include "tensorflow/core/ir/tf_op_names.inc"

// Caching some often used context-owned informations for fast-access.
name_key_ = StringAttr::get(getContext(), getNameAttrKey());
device_key_ = StringAttr::get(getContext(), getDeviceAttrKey());
Expand Down
Loading

0 comments on commit 230e240

Please sign in to comment.