diff --git a/tf_shell/cc/optimizers/ct_pt.cc b/tf_shell/cc/optimizers/ct_pt.cc index a49e66f..120cbbb 100644 --- a/tf_shell/cc/optimizers/ct_pt.cc +++ b/tf_shell/cc/optimizers/ct_pt.cc @@ -38,17 +38,16 @@ struct ReorderArith { int inner_pt_node_index; }; -void PrintReorderArith(RemapperContext& ctx, ReorderArith const& reorder) { - auto const* outer_node = - ctx.graph_view.GetNode(reorder.outer_node_index)->node(); - auto const* inner_node = - ctx.graph_view.GetNode(reorder.inner_node_index)->node(); +void PrintReorderArith(utils::MutableGraphView& graph_view, + ReorderArith const& reorder) { + auto const* outer_node = graph_view.GetNode(reorder.outer_node_index)->node(); + auto const* inner_node = graph_view.GetNode(reorder.inner_node_index)->node(); auto const* outer_pt_node = - ctx.graph_view.GetNode(reorder.outer_pt_node_index)->node(); + graph_view.GetNode(reorder.outer_pt_node_index)->node(); auto const* inner_ct_node = - ctx.graph_view.GetNode(reorder.inner_ct_node_index)->node(); + graph_view.GetNode(reorder.inner_ct_node_index)->node(); auto const* inner_pt_node = - ctx.graph_view.GetNode(reorder.inner_pt_node_index)->node(); + graph_view.GetNode(reorder.inner_pt_node_index)->node(); std::cout << outer_node->name() << " ( " << inner_node->name() << " ( " << inner_ct_node->name() << " , " << inner_pt_node->name() << " ), " @@ -60,9 +59,10 @@ void PrintReorderArith(RemapperContext& ctx, ReorderArith const& reorder) { // If the outer_op is add or sub, the inner_op must be add or sub. // If instead the outer_op is mul, the inner_op must be mul. // If the inner op is used elsewhere (has fanout>1), the pattern is not matched. -bool FindAddOrSub(RemapperContext& ctx, int node_index, ReorderArith* reorder) { +bool FindAddOrSub(utils::MutableGraphView& graph_view, int node_index, + ReorderArith* reorder) { // Check given node is op(ct, pt). 
- auto const* outer_node_view = ctx.graph_view.GetNode(node_index); + auto const* outer_node_view = graph_view.GetNode(node_index); auto const* outer_node_def = outer_node_view->node(); if (!IsAddCtPt(*outer_node_def) && !IsSubCtPt(*outer_node_def) && @@ -123,7 +123,7 @@ bool FindAddOrSub(RemapperContext& ctx, int node_index, ReorderArith* reorder) { if constexpr (debug) { std::cout << "Found pattern:"; - PrintReorderArith(ctx, new_reorder); + PrintReorderArith(graph_view, new_reorder); } *reorder = new_reorder; @@ -133,10 +133,11 @@ bool FindAddOrSub(RemapperContext& ctx, int node_index, ReorderArith* reorder) { // This function replaces the pattern outer_op(inner_op(ct, pt), pt) with // outer_op(ct, inner_op(pt, pt)). -Status ApplyReorderArith(RemapperContext* ctx, ReorderArith const& reorder, +Status ApplyReorderArith(utils::MutableGraphView& graph_view, + ReorderArith const& reorder, std::vector* nodes_to_delete) { - GraphDef const* graph = ctx->graph_view.graph(); - utils::Mutation* mutation = ctx->graph_view.GetMutationBuilder(); + GraphDef const* graph = graph_view.graph(); + utils::Mutation* mutation = graph_view.GetMutationBuilder(); Status status; // First replace the inner node with a pt pt node. @@ -210,14 +211,13 @@ Status CtPtOptimizer::Init( Status CtPtOptimizer::Optimize(Cluster* cluster, GrapplerItem const& item, GraphDef* optimized_graph) { - GrapplerItem mutable_item = item; + GrapplerItem mutable_item(item); Status status; - RemapperContext ctx(&mutable_item, &status); + utils::MutableGraphView graph_view(&mutable_item.graph, &status); TF_RETURN_IF_ERROR(status); // Topological sort and process the nodes in reverse. 
- TF_RETURN_IF_ERROR( - ctx.graph_view.SortTopologically(/*ignore_cycles=*/false, {})); + TF_RETURN_IF_ERROR(graph_view.SortTopologically(/*ignore_cycles=*/false, {})); bool finished = false; while (!finished) { @@ -232,17 +232,18 @@ Status CtPtOptimizer::Optimize(Cluster* cluster, GrapplerItem const& item, // Remap op( op(ct, pt), pt) to op(ct, op(pt, pt)). ReorderArith reorder; - if (FindAddOrSub(ctx, i, &reorder)) { - TF_RETURN_IF_ERROR(ApplyReorderArith(&ctx, reorder, &nodes_to_delete)); + if (FindAddOrSub(graph_view, i, &reorder)) { + TF_RETURN_IF_ERROR( + ApplyReorderArith(graph_view, reorder, &nodes_to_delete)); finished = false; } } // Remove nodes. - utils::Mutation* mutation = ctx.graph_view.GetMutationBuilder(); + utils::Mutation* mutation = graph_view.GetMutationBuilder(); for (int i = 0; i < num_nodes; ++i) { if (nodes_to_delete[i]) { - mutation->RemoveNode(ctx.graph_view.GetNode(i)); + mutation->RemoveNode(graph_view.GetNode(i)); } } TF_RETURN_IF_ERROR(mutation->Apply()); diff --git a/tf_shell/cc/optimizers/moduli_autotune.cc b/tf_shell/cc/optimizers/moduli_autotune.cc index 6238b0e..ee304f1 100644 --- a/tf_shell/cc/optimizers/moduli_autotune.cc +++ b/tf_shell/cc/optimizers/moduli_autotune.cc @@ -120,10 +120,10 @@ Status AddScalarConstNode(T value, utils::Mutation* mutation, return status; } -Status GetAutoShellContextParams(RemapperContext& ctx, +Status GetAutoShellContextParams(utils::MutableGraphView& graph_view, ShellAutoParams& params) { // Get the plaintext modulus t. 
- auto const* autocontext_node_view = ctx.graph_view.GetNode(kShellAutoContext); + auto const* autocontext_node_view = graph_view.GetNode(kShellAutoContext); // auto const* autocontext_node_def = autocontext_node_view->node(); auto const* cleartext_bits_node = @@ -156,14 +156,14 @@ Status GetAutoShellContextParams(RemapperContext& ctx, return OkStatus(); } -int GetMulDepth(RemapperContext& ctx) { +int GetMulDepth(utils::MutableGraphView& graph_view) { // Traverse the graph and return the maximum multiplicative depth. - int const num_nodes = ctx.graph_view.NumNodes(); + int const num_nodes = graph_view.NumNodes(); std::vector node_mul_depth(num_nodes); uint64_t max_noise = 0; for (int i = 0; i < num_nodes; ++i) { - auto const* this_node_view = ctx.graph_view.GetNode(i); + auto const* this_node_view = graph_view.GetNode(i); auto const* this_node_def = this_node_view->node(); if (IsArithmetic(*this_node_def) || IsMatMul(*this_node_def) || @@ -369,11 +369,11 @@ Status ChooseShellParams(ShellParams& params, uint64_t const total_pt_bits, // Returns the noise budget of the current node. template Status EstimateNodeNoise( - RemapperContext& ctx, int node_index, std::vector& node_noise, - ShellParams const& params, + utils::MutableGraphView& graph_view, int node_index, + std::vector& node_noise, ShellParams const& params, rlwe::RnsErrorParams>& error_params) { // Get the current node and its fanin nodes. - auto const* node_view = ctx.graph_view.GetNode(node_index); + auto const* node_view = graph_view.GetNode(node_index); auto const* node_def = node_view->node(); uint64_t* this_noise = &node_noise[node_index]; @@ -491,10 +491,11 @@ Status EstimateNodeNoise( } template -Status EstimateNoiseGrowth(RemapperContext& ctx, ShellParams const& params, +Status EstimateNoiseGrowth(utils::MutableGraphView& graph_view, + ShellParams const& params, uint64_t const noise_varaince, uint64_t* log_noise) { // Estimate the ciphertext noise growth by traversing the graph. 
- int const num_nodes = ctx.graph_view.NumNodes(); + int const num_nodes = graph_view.NumNodes(); std::vector node_noise(num_nodes); // Create RnsErrorParams. @@ -527,7 +528,7 @@ Status EstimateNoiseGrowth(RemapperContext& ctx, ShellParams const& params, for (int i = 0; i < num_nodes; ++i) { // Estimate the noise budget of this node. TF_RETURN_IF_ERROR( - EstimateNodeNoise(ctx, i, node_noise, params, error_params)); + EstimateNodeNoise(graph_view, i, node_noise, params, error_params)); // Update the maximum noise budget. log_max_noise = std::max(log_max_noise, node_noise[i]); @@ -544,11 +545,11 @@ Status EstimateNoiseGrowth(RemapperContext& ctx, ShellParams const& params, // Returns true if the node_index points to the outermost op of the pattern // decode(encode(a)) where a is a cleartext (tf datatype) and marks nodes to // delete accordingly. -Status ReplaceAutoparamWithContext(RemapperContext& ctx, +Status ReplaceAutoparamWithContext(utils::MutableGraphView& graph_view, ShellParams const& params, ShellAutoParams const& auto_params) { utils::MutableNodeView* autocontext_node_view = - ctx.graph_view.GetNode(kShellAutoContext); + graph_view.GetNode(kShellAutoContext); int autocontext_node_index = autocontext_node_view->node_index(); if constexpr (debug_graph) { @@ -564,7 +565,7 @@ Status ReplaceAutoparamWithContext(RemapperContext& ctx, std::string noise_var_name = "ContextImport64/noise_variance"; std::string seed_str_name = "ContextImport64/seed"; - utils::Mutation* mutation = ctx.graph_view.GetMutationBuilder(); + utils::Mutation* mutation = graph_view.GetMutationBuilder(); std::string device = autocontext_node_view->GetDevice(); TF_RETURN_IF_ERROR( AddScalarConstNode(params.log_n, mutation, log_n_name, device)); @@ -615,7 +616,7 @@ Status ReplaceAutoparamWithContext(RemapperContext& ctx, } } - mutation->RemoveNode(ctx.graph_view.GetNode(autocontext_node_index)); + mutation->RemoveNode(graph_view.GetNode(autocontext_node_index)); for (auto const& fanin : 
autocontext_node_view->GetRegularFanins()) { mutation->RemoveNode(fanin.node_view()); } @@ -637,14 +638,14 @@ Status ModuliAutotuneOptimizer::Init( Status ModuliAutotuneOptimizer::Optimize(Cluster* cluster, GrapplerItem const& item, GraphDef* optimized_graph) { - GrapplerItem mutable_item = item; + GrapplerItem mutable_item(item); Status status; - RemapperContext ctx(&mutable_item, &status); + utils::MutableGraphView graph_view(&mutable_item.graph, &status); TF_RETURN_IF_ERROR(status); // See if an autocontext node exists in the graph. If not, there is nothing // to do. - auto const* autocontext_view = ctx.graph_view.GetNode(kShellAutoContext); + auto const* autocontext_view = graph_view.GetNode(kShellAutoContext); if (autocontext_view == nullptr) { *optimized_graph = std::move(mutable_item.graph); return OkStatus(); @@ -654,7 +655,7 @@ Status ModuliAutotuneOptimizer::Optimize(Cluster* cluster, std::string duplicate_autocontext = kShellAutoContext; duplicate_autocontext += "_1"; auto const* duplicate_autocontext_view = - ctx.graph_view.GetNode(duplicate_autocontext); + graph_view.GetNode(duplicate_autocontext); if (duplicate_autocontext_view != nullptr) { return errors::FailedPrecondition("Multiple autocontext nodes found."); } @@ -662,17 +663,16 @@ Status ModuliAutotuneOptimizer::Optimize(Cluster* cluster, // Use GetScalarConstValue to get value of plaintext modulus, // etc. ShellAutoParams auto_params; - TF_RETURN_IF_ERROR(GetAutoShellContextParams(ctx, auto_params)); + TF_RETURN_IF_ERROR(GetAutoShellContextParams(graph_view, auto_params)); // Topological sort so all subsequent traversals are in order. - TF_RETURN_IF_ERROR( - ctx.graph_view.SortTopologically(/*ignore_cycles=*/false, {})); + TF_RETURN_IF_ERROR(graph_view.SortTopologically(/*ignore_cycles=*/false, {})); // Find the maximum multiplicative depth of the graph and use this to set // the plaintext modulus t, based on the scaling factor and depth. 
// Note the mul_depth + 1 accounts for the first multiplication by the // scaling factor during encoding. - int mul_depth = GetMulDepth(ctx); + int mul_depth = GetMulDepth(graph_view); uint64_t total_plaintext_bits = auto_params.cleartext_bits + std::ceil(std::log2( @@ -713,7 +713,7 @@ Status ModuliAutotuneOptimizer::Optimize(Cluster* cluster, uint64_t log_max_noise = 0; TF_RETURN_IF_ERROR(EstimateNoiseGrowth( - ctx, params, auto_params.noise_variance, &log_max_noise)); + graph_view, params, auto_params.noise_variance, &log_max_noise)); if (log_max_noise == 0) { // No encryption in this graph. Smallest parameters will do. @@ -756,14 +756,14 @@ Status ModuliAutotuneOptimizer::Optimize(Cluster* cluster, std::cout << std::endl; } - TF_RETURN_IF_ERROR(ReplaceAutoparamWithContext(ctx, params, auto_params)); + TF_RETURN_IF_ERROR( + ReplaceAutoparamWithContext(graph_view, params, auto_params)); if constexpr (debug_graph) { std::cout << "Optimized graph: " << std::endl; - int const num_nodes = ctx.graph_view.NumNodes(); + int const num_nodes = graph_view.NumNodes(); for (int i = 0; i < num_nodes; ++i) { - std::cout << ctx.graph_view.GetNode(i)->node()->DebugString() - << std::endl; + std::cout << graph_view.GetNode(i)->node()->DebugString() << std::endl; } } diff --git a/tf_shell/cc/optimizers/pt_pt.cc b/tf_shell/cc/optimizers/pt_pt.cc index ef9790d..4b4f4aa 100644 --- a/tf_shell/cc/optimizers/pt_pt.cc +++ b/tf_shell/cc/optimizers/pt_pt.cc @@ -45,13 +45,13 @@ struct ReorderArith { int pt_op_node_index; }; -void PrintReorderArith(RemapperContext& ctx, ReorderArith const& reorder) { - auto const* pt_op_node = - ctx.graph_view.GetNode(reorder.pt_op_node_index)->node(); +void PrintReorderArith(utils::MutableGraphView& graph_view, + ReorderArith const& reorder) { + auto const* pt_op_node = graph_view.GetNode(reorder.pt_op_node_index)->node(); auto const* encode_a_input_node = - ctx.graph_view.GetNode(reorder.encode_a_input_node)->node(); + 
graph_view.GetNode(reorder.encode_a_input_node)->node(); auto const* encode_a_node = - ctx.graph_view.GetNode(reorder.encode_a_node_index)->node(); + graph_view.GetNode(reorder.encode_a_node_index)->node(); std::cout << pt_op_node->name() << "( " << encode_a_node->name() << "(" << encode_a_input_node->name() << ")"; @@ -60,9 +60,9 @@ void PrintReorderArith(RemapperContext& ctx, ReorderArith const& reorder) { std::cout << " )" << std::endl; } else { auto const* encode_b_input_node = - ctx.graph_view.GetNode(reorder.encode_b_input_node)->node(); + graph_view.GetNode(reorder.encode_b_input_node)->node(); auto const* encode_b_node = - ctx.graph_view.GetNode(reorder.encode_b_node_index)->node(); + graph_view.GetNode(reorder.encode_b_node_index)->node(); std::cout << ", " << encode_b_node->name() << "(" << encode_b_input_node->name() << ") )" << std::endl; } @@ -71,8 +71,9 @@ void PrintReorderArith(RemapperContext& ctx, ReorderArith const& reorder) { // Returns true if the node_index points to the outermost op of the pattern // outer_plaintext_op(encode(a), encode(b)) and fills the ReorderArith struct // accordingly. -bool FindPtPt(RemapperContext& ctx, int node_index, ReorderArith* reorder) { - auto const* outer_node_view = ctx.graph_view.GetNode(node_index); +bool FindPtPt(utils::MutableGraphView& graph_view, int node_index, + ReorderArith* reorder) { + auto const* outer_node_view = graph_view.GetNode(node_index); auto const* outer_node_def = outer_node_view->node(); if (!IsReplaceableOp(*outer_node_def)) { @@ -134,7 +135,7 @@ bool FindPtPt(RemapperContext& ctx, int node_index, ReorderArith* reorder) { if constexpr (debug) { std::cout << "Found pattern: "; - PrintReorderArith(ctx, *reorder); + PrintReorderArith(graph_view, *reorder); } return true; @@ -144,10 +145,11 @@ bool FindPtPt(RemapperContext& ctx, int node_index, ReorderArith* reorder) { // encode( op(a, ) ) where <> indicate optional arguments. 
One might wonder // why single arg ops are re-arranged in this way, since there is no performance // gain. The reason is so subsequent ops can be optimized. -Status ApplyReorderArith(RemapperContext* ctx, ReorderArith const& reorder, +Status ApplyReorderArith(utils::MutableGraphView& graph_view, + ReorderArith const& reorder, std::vector* nodes_to_delete) { - GraphDef const* graph = ctx->graph_view.graph(); - utils::Mutation* mutation = ctx->graph_view.GetMutationBuilder(); + GraphDef const* graph = graph_view.graph(); + utils::Mutation* mutation = graph_view.GetMutationBuilder(); Status status; // First, replace the outer PtPt node with the TensorFlow equivalent and set @@ -167,7 +169,7 @@ Status ApplyReorderArith(RemapperContext* ctx, ReorderArith const& reorder, // Set the dtype of the new tf node to the same as the input. auto const* encode_a_node_view = - ctx->graph_view.GetNode(reorder.encode_a_node_index); + graph_view.GetNode(reorder.encode_a_node_index); auto const* encode_a_node_def = encode_a_node_view->node(); auto dtype = encode_a_node_def->attr().at("Dtype"); new_tf_op_def.mutable_attr()->insert({"T", dtype}); @@ -214,7 +216,7 @@ Status ApplyReorderArith(RemapperContext* ctx, ReorderArith const& reorder, (*nodes_to_delete)[reorder.pt_op_node_index] = true; // auto const* encode_a_node_view = - // ctx->graph_view.GetNode(reorder.encode_a_node_index); + // graph_view.GetNode(reorder.encode_a_node_index); if (encode_a_node_view->NumRegularFanouts() == 1) { (*nodes_to_delete)[reorder.encode_a_node_index] = true; @@ -232,7 +234,7 @@ Status ApplyReorderArith(RemapperContext* ctx, ReorderArith const& reorder, if (!reorder.single_arg_op) { auto const* encode_b_node_view = - ctx->graph_view.GetNode(reorder.encode_b_node_index); + graph_view.GetNode(reorder.encode_b_node_index); if (encode_b_node_view->NumRegularFanouts() == 1) { (*nodes_to_delete)[reorder.encode_b_node_index] = true; @@ -255,9 +257,9 @@ Status ApplyReorderArith(RemapperContext* ctx, 
ReorderArith const& reorder, // Returns true if the node_index points to the outermost op of the pattern // decode(encode(a)) where a is a cleartext (tf datatype) and marks nodes to // delete accordingly. -bool FindAndRemapEncDec(RemapperContext& ctx, int node_index, +bool FindAndRemapEncDec(utils::MutableGraphView& graph_view, int node_index, utils::Mutation* mutation) { - auto const* decode_node_view = ctx.graph_view.GetNode(node_index); + auto const* decode_node_view = graph_view.GetNode(node_index); auto const* decode_node_def = decode_node_view->node(); if (!IsDecode(*decode_node_def)) { @@ -281,16 +283,15 @@ bool FindAndRemapEncDec(RemapperContext& ctx, int node_index, // the decoder node is an output of the graph. Downstream nodes automatically // pick up on their new fanin inputs after the rename. utils::MutableNodeView* mutable_tf_input = - ctx.graph_view.GetNode(tf_input_node_view->node_index()); + graph_view.GetNode(tf_input_node_view->node_index()); mutation->UpdateNodeName(mutable_tf_input, decode_node_def->name()); // Delete the decode node. - mutation->RemoveNode(ctx.graph_view.GetNode(node_index)); + mutation->RemoveNode(graph_view.GetNode(node_index)); // Only delete the encode node if it is not used elsewhere. if (encode_node_view->NumRegularFanouts() == 1) { - mutation->RemoveNode( - ctx.graph_view.GetNode(encode_node_view->node_index())); + mutation->RemoveNode(graph_view.GetNode(encode_node_view->node_index())); } return true; @@ -307,14 +308,13 @@ Status PtPtOptimizer::Init( Status PtPtOptimizer::Optimize(Cluster* cluster, GrapplerItem const& item, GraphDef* optimized_graph) { - GrapplerItem mutable_item = item; + GrapplerItem mutable_item(item); Status status; - RemapperContext ctx(&mutable_item, &status); + utils::MutableGraphView graph_view(&mutable_item.graph, &status); TF_RETURN_IF_ERROR(status); // Topological sort and process the nodes in order. 
- TF_RETURN_IF_ERROR( - ctx.graph_view.SortTopologically(/*ignore_cycles=*/false, {})); + TF_RETURN_IF_ERROR(graph_view.SortTopologically(/*ignore_cycles=*/false, {})); bool finished = false; while (!finished) { @@ -331,17 +331,18 @@ Status PtPtOptimizer::Optimize(Cluster* cluster, GrapplerItem const& item, // indicate optional arguments. E.g. op=add has two arguments while // op=negate has only one. ReorderArith reorder; - if (FindPtPt(ctx, i, &reorder)) { - TF_RETURN_IF_ERROR(ApplyReorderArith(&ctx, reorder, &nodes_to_delete)); + if (FindPtPt(graph_view, i, &reorder)) { + TF_RETURN_IF_ERROR( + ApplyReorderArith(graph_view, reorder, &nodes_to_delete)); finished = false; } } // Remove nodes. - utils::Mutation* mutation = ctx.graph_view.GetMutationBuilder(); + utils::Mutation* mutation = graph_view.GetMutationBuilder(); for (int i = 0; i < num_nodes; ++i) { if (nodes_to_delete[i]) { - mutation->RemoveNode(ctx.graph_view.GetNode(i)); + mutation->RemoveNode(graph_view.GetNode(i)); } } TF_RETURN_IF_ERROR(mutation->Apply()); @@ -351,10 +352,10 @@ Status PtPtOptimizer::Optimize(Cluster* cluster, GrapplerItem const& item, // Since encode-decode pairs will never be nested, i.e. // decode(decode(encode(encode(...))), only one pass is necessary. 
{ - utils::Mutation* mutation = ctx.graph_view.GetMutationBuilder(); + utils::Mutation* mutation = graph_view.GetMutationBuilder(); int const num_nodes = mutable_item.graph.node_size(); for (int i = num_nodes - 1; i >= 0; --i) { - FindAndRemapEncDec(ctx, i, mutation); + FindAndRemapEncDec(graph_view, i, mutation); } TF_RETURN_IF_ERROR(mutation->Apply()); } diff --git a/tf_shell/cc/optimizers/utils.h b/tf_shell/cc/optimizers/utils.h index 95b035e..76157a8 100644 --- a/tf_shell/cc/optimizers/utils.h +++ b/tf_shell/cc/optimizers/utils.h @@ -12,19 +12,7 @@ #include "tensorflow/core/grappler/utils/topological_sort.h" using tensorflow::NodeDef; -using tensorflow::Status; -using tensorflow::grappler::GraphProperties; -using tensorflow::grappler::GrapplerItem; -using tensorflow::grappler::utils::MutableGraphView; -using TopMutableGraphView = tensorflow::grappler::MutableGraphView; - -struct RemapperContext { - explicit RemapperContext(GrapplerItem* item, Status* status) - : graph_view(&item->graph, status), graph_properties(*item) {} - - MutableGraphView graph_view; - GraphProperties graph_properties; -}; + constexpr char kShellContext[] = "ContextImport64"; constexpr char kShellAutoContext[] = "AutoShellContext64"; diff --git a/tf_shell/python/shell_optimizers.py b/tf_shell/python/shell_optimizers.py index 1873a54..b174a27 100644 --- a/tf_shell/python/shell_optimizers.py +++ b/tf_shell/python/shell_optimizers.py @@ -25,6 +25,8 @@ ) # Based on https://github.com/openvinotoolkit/openvino_tensorflow/blob/d9dcb9d4c5932d0a8e9a3633d4134ae5841af6c1/python/openvino_tensorflow/__init__.in.py +# Another approach using higher level APIs can be found here: +# https://stackoverflow.com/questions/74219568/optimize-and-resave-saved-model-with-grappler from tensorflow.python.framework import ops from tensorflow.core.protobuf import rewriter_config_pb2