From 0ed111d038a62d42849b488406d900286690aa64 Mon Sep 17 00:00:00 2001
From: George Vasilakopoulos
Date: Tue, 27 Feb 2024 19:22:17 +0000
Subject: [PATCH 01/53] Migration Changes

---
 .../core/common_runtime/graph_constructor.cc  |  96 ++-
 .../common_runtime/graph_execution_state.cc   |  26 +
 tensorflow/core/common_runtime/graph_view.h   |  10 +
 .../immutable_executor_state.cc               |  47 ++
 .../core/common_runtime/propagator_state.cc   | 106 ++-
 .../core/common_runtime/propagator_state.h    |  18 +-
 tensorflow/core/graph/graph.cc                |   2 +
 tensorflow/core/graph/graph.h                 |   7 +
 .../core/grappler/function_transformation.h   |  41 ++
 tensorflow/core/grappler/op_types.cc          |  12 +-
 tensorflow/core/grappler/op_types.h           |   2 +
 tensorflow/core/grappler/optimizers/BUILD     |  41 ++
 .../optimizers/function_transformation.cc     | 650 ++++++++++++++++++
 .../function_transformation_test.cc           |  57 ++
 .../grappler/optimizers/meta_optimizer.cc     |  17 +-
 .../core/grappler/utils/topological_sort.cc   |  36 +
 tensorflow/core/kernels/BUILD                 |   9 +
 .../core/kernels/function_control_ops.cc      | 191 +++++
 .../core/kernels/function_control_ops.h       |  47 ++
 tensorflow/core/ops/BUILD                     |   2 +
 tensorflow/core/ops/function_control_ops.cc   | 116 ++++
 .../core/protobuf/rewriter_config.proto       |   2 +
 22 files changed, 1497 insertions(+), 38 deletions(-)
 create mode 100644 tensorflow/core/grappler/function_transformation.h
 create mode 100644 tensorflow/core/grappler/optimizers/function_transformation.cc
 create mode 100644 tensorflow/core/grappler/optimizers/function_transformation_test.cc
 create mode 100644 tensorflow/core/kernels/function_control_ops.cc
 create mode 100644 tensorflow/core/kernels/function_control_ops.h
 create mode 100644 tensorflow/core/ops/function_control_ops.cc

diff --git a/tensorflow/core/common_runtime/graph_constructor.cc b/tensorflow/core/common_runtime/graph_constructor.cc
index 66109aee89eaa9..b45b76f0ffc776 100644
--- a/tensorflow/core/common_runtime/graph_constructor.cc
+++ b/tensorflow/core/common_runtime/graph_constructor.cc
@@ -63,6 +63,15 @@ namespace {
 // can skip expensive duplicates check in 'AddControlEdge'.
static constexpr const bool kDoNotCheckDuplicates = true; +inline bool IsCall(const NodeDef& node_def){ + return node_def.op() == "Call" || node_def.op() == "RefCall"; +} + +inline bool IsReturn(const NodeDef& node_def){ + return node_def.op() == "Return" || node_def.op() == "RefReturn"; +} + + inline bool IsMerge(const NodeDef& node_def) { return node_def.op() == "Merge" || node_def.op() == "RefMerge" || node_def.op() == "_XlaMerge"; @@ -201,6 +210,7 @@ class GraphConstructor { TF_RETURN_IF_ERROR(EnsureNoNameCollisions()); TF_RETURN_IF_ERROR(ValidateInputMapAndControlDependencies()); TF_RETURN_IF_ERROR(BuildNodeIndex()); + TF_RETURN_IF_ERROR(PopulateFunctionReturningNodes()); TF_RETURN_IF_ERROR(InitFromEdges()); // NOTE: Convert() invokes `consume_node_def()` on each node in the input @@ -228,6 +238,7 @@ class GraphConstructor { Status PopulateReturnTensors(); Status PopulateReturnNodes(); Status PopulateMissingUnusedInputMapKeys(); + Status PopulateFunctionReturningNodes(); FunctionDefLibraryStackTraces CreateStackTracesForFunctionDefLibrary( const FunctionDefLibrary& library) const; @@ -261,6 +272,10 @@ class GraphConstructor { void AddPrefixToNodeDef(const std::vector& input_already_exists, NodeDef* node_def); + bool IsReturningNode(const NodeDef& node_def){ + return (function_returning_nodes_.find(node_def.name()) != function_returning_nodes_.end()); + } + // Modifies `node_def` if its name isn't unique, or if any of its inputs' // names have been uniquified. This must be called in topological order on all // nodes. @@ -405,6 +420,7 @@ class GraphConstructor { int dst_index; }; std::vector back_edges_; + std::unordered_set function_returning_nodes_; GraphConstructor(const GraphConstructor&) = delete; void operator=(const GraphConstructor&) = delete; @@ -646,6 +662,44 @@ Status GraphConstructor::EnsureNoNameCollisions() { return absl::OkStatus(); } +Status GraphConstructor::PopulateFunctionReturningNodes() { + std::unordered_map> returning_nodes; + for (int n = 0; n < node_def_count(); ++n) { + const NodeDef& node_def = get_node_def(); + if (IsReturn(node_def)){ + // Nodes that send their output to "Return" nodes are + // function Returning Nodes and in case of recursive functions + // those nodes are part of graph cycles. 
+ for (const auto& input_name : node_def.input()) { + // In order to detect the recursion cycles we depend on + // the fact that a recursive function's returning node, + // will be sending outputs to at least 2 "Return" nodes + // with different "call_id" attributes (same "call_id" + // attrs would mean that they belong in the same function call + // but they correspond to different function outputs) + if (!absl::StartsWith(input_name, "^")) { + string prevNode = input_name; + size_t pos = input_name.find(":"); + if (pos != std::string::npos) + prevNode = input_name.substr(0, pos); + + + int call_id; + GetNodeAttr(AttrSlice(node_def), "call_id", &call_id); + returning_nodes[prevNode].emplace(call_id); + } + } + } + } + for (auto& retnode : returning_nodes) { + if (retnode.second.size() > 1) { + // Detected Cycle + function_returning_nodes_.insert(retnode.first); + } + } + return OkStatus(); +} + Status GraphConstructor::ValidateInputMapAndControlDependencies() { for (const auto& mapping : opts_.input_map) { TensorId src = mapping.first; @@ -729,15 +783,28 @@ Status GraphConstructor::InitFromEdges() { } } + gt1::FlatSet call_nodes; + for (int n = 0; n < node_def_count(); ++n) { + const NodeDef& node_def = get_node_def(n); + if (IsCall(node_def)) { + call_nodes.insert(node_def.name()); + } + } + + + + // Parse the inputs for each node. for (int n = 0; n < num_nodes; ++n) { const NodeDef& node_def = get_node_def(n); int pending_count = node_def.input_size(); - if (IsMerge(node_def)) { - // Cycles in the graph are only allowed for while loops. A while loop is - // identified by an edge from a NextIteration node to a Merge node. For - // such Merge nodes, only wait for one non-control input before - // considering the node ready to process in Convert(). + if (IsMerge(node_def) && !IsReturningNode(node_def)) { + // Cycles in the graph are only allowed for while loops and recursion. + // A while loop is identified by an edge from a NextIteration node to a Merge node. + // A recursion is identified by an edge from a Call Node to a Merge node + // In recursion, function returning nodes also participate in a cycle + // For such Merge nodes, and for function returning nodes only wait for + // one non-control input before considering the node ready to process in Convert(). 
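+      // For instance (illustrative): the Merge that feeds a recursive
+      // function's argument receives one edge from the outer Call and one from
+      // the recursive Call inside the body; with such a loop-back edge its
+      // pending count is set to num_control_edges + 1 below, so the outer Call
+      // alone is enough to make it ready.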
int32_t num_control_edges = 0; bool has_loop_back_edge = false; for (int i = 0; i < node_def.input_size(); ++i) { @@ -747,7 +814,9 @@ Status GraphConstructor::InitFromEdges() { } else { TensorId id(ParseTensorName(input_name)); if (next_iteration_nodes.find(string(id.first)) != - next_iteration_nodes.end()) { + next_iteration_nodes.end()|| + call_nodes.find(string(id.first)) != + call_nodes.end()) { has_loop_back_edge = true; } } @@ -755,7 +824,16 @@ Status GraphConstructor::InitFromEdges() { if (has_loop_back_edge) { pending_count = num_control_edges + 1; } - } + } else if (IsReturningNode(node_def)) { + int num_control_edges = 0; + for (int i = 0; i < node_def.input_size(); ++i) { + StringPiece input_name(node_def.input(i)); + if (absl::StartsWith(input_name, "^")) { + num_control_edges++; + } + } + pending_count = num_control_edges + 1; + } for (int i = 0; i < node_def.input_size(); ++i) { StringPiece input_name = node_def.input(i); TensorId id(ParseTensorName(input_name)); @@ -1277,10 +1355,10 @@ Status GraphConstructor::Convert() { inputs.emplace_back(string(tensor_id.node()), src_node, src_index); } - if (has_data_back_edge && !IsMerge(node_def)) { + if (has_data_back_edge && !IsMerge(node_def) && !IsReturningNode(node_def)) { return errors::InvalidArgument( "Node '", node_def.name(), - "' had a back edge, but only Merge nodes can have back edges."); + "' had a back edge, but only Merge and returning nodes can have back edges."); } Node* node; diff --git a/tensorflow/core/common_runtime/graph_execution_state.cc b/tensorflow/core/common_runtime/graph_execution_state.cc index c2e115b80c3cb9..19cfd1c4fbaff5 100644 --- a/tensorflow/core/common_runtime/graph_execution_state.cc +++ b/tensorflow/core/common_runtime/graph_execution_state.cc @@ -61,6 +61,9 @@ limitations under the License. 
#include "tensorflow/core/grappler/optimizers/meta_optimizer.h" #endif // IS_MOBILE_PLATFORM +#include "tensorflow/core/util/event.pb.h" +#include "tensorflow/core/util/events_writer.h" + namespace tensorflow { namespace { @@ -849,6 +852,29 @@ Status GraphExecutionState::OptimizeGraph( for (Node* node : optimized_graph->get()->nodes()) { node->set_assigned_device_name(node->requested_device()); } + + /*******************************************************************************************/ + // Write an event, so that we can visualize this optimized graph in tensorboard + EventsWriter writer("Fully_Optimized"); + Event event; + event.set_wall_time(1234); + event.set_step(34); + + const size_t proto_size = new_graph.ByteSizeLong(); + void* buf = port::Malloc(proto_size); + if (buf == nullptr) { + return tensorflow::errors::ResourceExhausted("Failed to allocate memory to serialize message of type '" + ,new_graph.GetTypeName(), "' and size ", proto_size); + } + new_graph.SerializeToArray(buf, proto_size); + const void* bf = buf; + event.set_graph_def(bf, proto_size); + writer.WriteEvent(event); + /*******************************************************************************************/ + printf("Transformation passed successfully"); + + + return absl::OkStatus(); } else { return errors::InvalidArgument("Meta Optimizer disabled"); diff --git a/tensorflow/core/common_runtime/graph_view.h b/tensorflow/core/common_runtime/graph_view.h index ed9b14cfa1f73d..e3396f7d36ccd3 100644 --- a/tensorflow/core/common_runtime/graph_view.h +++ b/tensorflow/core/common_runtime/graph_view.h @@ -67,10 +67,14 @@ struct NodeItem { bool is_constant_enter : 1; // True iff IsEnter(node) and // node->GetAttr("is_constant") == true. bool is_exit : 1; // True iff IsExit(node) + bool is_call : 1; // True iff IsCall(node) + bool is_return : 1; // True iff IsReturn(node) bool is_control_trigger : 1; // True iff IsControlTrigger(node) bool is_source : 1; // True iff IsSource(node) // True iff IsEnter(node) || IsExit(node) || IsNextIteration(node) bool is_enter_exit_or_next_iter : 1; + // True iff IsCall(node) || IsReturn(node) + bool is_call_or_return : 1; bool is_transfer_node : 1; // True iff IsTransferNode(node) bool is_initialization_op : 1; // True iff IsInitializationOp(node) bool is_recv_or_switch : 1; // True iff IsRecv(node) || IsSwitch(node) @@ -107,6 +111,12 @@ struct NodeItem { // Number of output control edges. int32 num_output_control_edges; + string frame_name; // cache the attribute if is_enter | is-exit | is_call | is_return + string dyn_frame_name; // cache the attribute if is_enter | is-exit | is_call | is_return + + int call_id = -1; + + // If non-null, contains an array of num_outputs bools, where the ith bool // is true if and only if the ith output is consumed by another node. std::unique_ptr outputs_required; diff --git a/tensorflow/core/common_runtime/immutable_executor_state.cc b/tensorflow/core/common_runtime/immutable_executor_state.cc index e3a2435505e041..1ef9ad362dd9c1 100644 --- a/tensorflow/core/common_runtime/immutable_executor_state.cc +++ b/tensorflow/core/common_runtime/immutable_executor_state.cc @@ -96,6 +96,8 @@ Status ImmutableExecutorState::Initialize(const Graph& graph) { pending_ids_.resize(gview_.num_nodes()); + std::unordered_map input_count; + // Preprocess every node in the graph to create an instance of op // kernel for each node. 
requires_control_flow_ = false; @@ -103,6 +105,8 @@ Status ImmutableExecutorState::Initialize(const Graph& graph) { if (IsSink(n)) continue; if (IsSwitch(n) || IsMerge(n) || IsEnter(n) || IsExit(n)) { requires_control_flow_ = true; + } else if(IsCall(n) || IsReturn(n)){ + requires_control_flow_ = true; } else if (IsRecv(n)) { // A Recv node from a different device may produce dead tensors from // non-local control-flow nodes. @@ -186,6 +190,9 @@ Status ImmutableExecutorState::Initialize(const Graph& graph) { } else { item->is_constant_enter = false; } + item->is_call = IsCall(n); + item->is_return = IsReturn(n); + item->is_call_or_return = (IsCall(n) || IsReturn(n)); item->is_exit = IsExit(n); item->is_control_trigger = IsControlTrigger(n); item->is_source = IsSource(n); @@ -217,6 +224,19 @@ Status ImmutableExecutorState::Initialize(const Graph& graph) { string enter_name; TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "frame_name", &enter_name)); EnsureFrameInfo(enter_name)->input_count++; + item->frame_name = enter_name; + item->dyn_frame_name = enter_name; + } + if(item->is_call_or_return){ + TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "frame_name", &item->frame_name)); + TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "call_id", &item->call_id)); + item->dyn_frame_name = strings::StrCat(item->call_id); + } + if (item->is_call) { + input_count[item->dyn_frame_name]++; + // The following assumes that all the calls of same function have the same number of inputs + // which is of course apparent for a well-formed graph (produced by the transformation) + EnsureFrameInfo(item->frame_name)->input_count = input_count[item->dyn_frame_name]; } // Record information about whether each output of the op is used. @@ -303,6 +323,7 @@ Status ImmutableExecutorState::BuildControlFlowInfo(const Graph* g, } } + std::unordered_map call_id_to_call_node_id; while (!ready.empty()) { Node* curr_node = ready.front(); int curr_id = curr_node->id(); @@ -323,6 +344,31 @@ Status ImmutableExecutorState::BuildControlFlowInfo(const Graph* g, } frame_name = cf_info->frame_names[parent->id()]; parent = parent_nodes[parent->id()]; + } else if (IsCall(curr_node)) { + TF_RETURN_IF_ERROR( + GetNodeAttr(curr_node->attrs(), "frame_name", &frame_name)); + + int call_id; + TF_RETURN_IF_ERROR( + GetNodeAttr(curr_node->attrs(),"call_id", &call_id)); + // we assume that call_id is unique and we don't need to concat with frame_name + // to make it unique. + call_id_to_call_node_id.emplace(call_id, curr_id); + parent = curr_node; + } else if (IsReturn(curr_node)) { + int call_id; + TF_RETURN_IF_ERROR( + GetNodeAttr(curr_node->attrs(), "call_id", &call_id)); + + auto it = call_id_to_call_node_id.find(call_id); + if (it != call_id_to_call_node_id.end()) { + int call_node_id = it->second; + parent = parent_nodes[call_node_id]; + frame_name = cf_info->frame_names[call_node_id]; + } else { + ready.push_back(curr_node); + continue; + } } else { parent = parent_nodes[curr_id]; frame_name = cf_info->frame_names[curr_id]; @@ -331,6 +377,7 @@ Status ImmutableExecutorState::BuildControlFlowInfo(const Graph* g, for (const Edge* out_edge : curr_node->out_edges()) { Node* out = out_edge->dst(); if (IsSink(out)) continue; + if (IsReturn(out) && out_edge->IsControlEdge()) continue; const int out_id = out->id(); // Add to ready queue if not visited. 
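(The executor changes above read "frame_name" and "call_id" straight off Call/Return nodes; those attributes are written by the grappler pass further down in this patch, in AddCallOp/AddRetOp. A minimal sketch of the node shape being assumed, using the "Call" op this patch registers; the node/input names, dtype and id values below are illustrative, not taken from the patch:)

    #include "tensorflow/core/framework/node_def.pb.h"
    #include "tensorflow/core/framework/node_def_builder.h"
    #include "tensorflow/core/framework/types.h"

    // Builds the Call node that feeds formal argument 0 of one invocation of a
    // hypothetical function "Fib". The matching Return nodes carry the same
    // frame_name/call_id pair, plus control inputs from every Call of that
    // invocation.
    tensorflow::Status MakeExampleCall(tensorflow::NodeDef* call) {
      return tensorflow::NodeDefBuilder("fib_out/Call_0", "Call")
          .Input("n", 0, tensorflow::DT_INT32)  // actual argument fed into the frame
          .Attr("T", tensorflow::DT_INT32)      // dtype forwarded across the frame boundary
          .Attr("frame_name", "Fib")            // static frame: one per function definition
          .Attr("call_id", 7)                   // dynamic frame: one per call site/invocation
          .Attr("arg_id", 0)                    // which formal parameter this Call feeds
          .Attr("is_constant", false)
          .Finalize(call);
    }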
diff --git a/tensorflow/core/common_runtime/propagator_state.cc b/tensorflow/core/common_runtime/propagator_state.cc index 9a365177770d4a..4ec560ce2d7aa6 100644 --- a/tensorflow/core/common_runtime/propagator_state.cc +++ b/tensorflow/core/common_runtime/propagator_state.cc @@ -42,13 +42,13 @@ PropagatorState::PropagatorState(const ImmutableExecutorState& immutable_state, 0, new PropagatorState::IterationState(0, root_frame_->pending_counts, root_frame_->total_input_tensors)); - outstanding_frames_.emplace(root_frame_->frame_id, root_frame_); + // outstanding_frames_.emplace(root_frame_->frame_id, root_frame_); } PropagatorState::~PropagatorState() { - for (auto name_frame : outstanding_frames_) { - delete name_frame.second; - } + // for (auto name_frame : outstanding_frames_) { + // delete name_frame.second; + // } } void PropagatorState::ActivateRoots(gtl::ArraySlice roots, @@ -89,7 +89,19 @@ void PropagatorState::PropagateOutputs(const TaggedNode& tagged_node, FrameState* output_frame = input_frame; IterationState* output_iter = input_iter; - if (!item->is_enter_exit_or_next_iter) { + // if (!item->is_enter_exit_or_next_iter) { + // if (vlog_) { + // VLOG(2) << "Propagate Outputs: " << node->name(); + // VLOG(2) << "Frame: " << input_frame->frame_name; + // } + // printf("Propagate Outputs: %s, am i alive? %d\n", node->name().c_str(), !is_dead); + // printf("Frame: %s\n", input_frame->frame_name.c_str()); + + // printf("Propagate Outputs: %s, am i alive? %d\n", node->name().c_str(), !is_dead); + // printf("Frame: %s\n", input_frame->frame_name.c_str()); + + + if (!item->is_enter_exit_or_next_iter && !item->is_call_or_return) { // Fast path for node types that don't need special handling. // This is the case for most nodes. DCHECK_EQ(input_frame, output_frame); @@ -131,6 +143,35 @@ void PropagatorState::PropagateOutputs(const TaggedNode& tagged_node, /*decrement_activation=*/0); is_frame_done = input_frame->DecrementOutstandingOps(input_iter, ready); } + } else if (item->is_call) { + // if (is_dead) { + // // Stop the deadness propagation. + // output_frame = nullptr; + // } else { + FindOrCreateChildFrame(input_frame, input_iter, *item, &output_frame); + output_iter = 0; + { + mutex_lock l(output_frame->mu); + int activated = output_frame->ActivateNodesLocked( + item, is_dead, output_iter, outputs, ready); + output_frame->AdjustOutstandingOpsLocked(output_iter, activated, ready); + output_frame->num_pending_inputs--; + } + is_frame_done = input_frame->DecrementOutstandingOps(input_iter, ready); + } else if (item->is_return) { + // if (is_dead) { + // // Stop the deadness propagation. 
+ // output_frame = nullptr; + // } else { + output_frame = input_frame->parent_frame; + output_iter = input_frame->parent_iter; + { + mutex_lock l(output_frame->mu); + int activated = output_frame->ActivateNodesLocked( + item, is_dead, output_iter, outputs, ready); + output_frame->AdjustOutstandingOpsLocked(output_iter, activated, ready); + } + is_frame_done = input_frame->DecrementOutstandingOps(input_iter, ready); } else { DCHECK(item->is_next_iteration); if (is_dead) { @@ -244,11 +285,11 @@ void PropagatorState::DumpIterationState(const FrameState* frame, void PropagatorState::DumpState() { mutex_lock l(mu_); LOG(WARNING) << "Dumping state"; - for (auto& frame : outstanding_frames_) { - LOG(WARNING) << frame.first; - FrameState* frame_state = frame.second; - frame_state->DumpIterationState(this); - } + // for (auto& frame : outstanding_frames_) { + // LOG(WARNING) << frame.first; + // FrameState* frame_state = frame.second; + // frame_state->DumpIterationState(this); + // } } void PropagatorState::FindOrCreateChildFrame(FrameState* frame, @@ -264,9 +305,9 @@ void PropagatorState::FindOrCreateChildFrame(FrameState* frame, Hash64Combine(iter_state->iter_num, Hash64(frame_info.name))); { - tf_shared_lock executor_lock(mu_); - auto it = outstanding_frames_.find(child_id); - if (it != outstanding_frames_.end()) { + tf_shared_lock executor_lock(frame->mu); + auto it = frame->outstanding_child_frames_.find(child_id); + if (it != frame->outstanding_child_frames_.end()) { *child = it->second; return; } @@ -285,6 +326,7 @@ void PropagatorState::FindOrCreateChildFrame(FrameState* frame, temp->frame_id = child_id; temp->parent_frame = frame; temp->parent_iter = iter_state; + temp->call_id = node_item.call_id; temp->InitializeFrameInfo(frame_info); // Initialize iteration 0. @@ -295,14 +337,13 @@ void PropagatorState::FindOrCreateChildFrame(FrameState* frame, } { - mutex_lock executor_lock(mu_); - auto it = outstanding_frames_.find(child_id); - if (it != outstanding_frames_.end()) { + mutex_lock executor_lock(frame->mu); + auto it = frame->outstanding_child_frames_.find(child_id); + if (it != frame->outstanding_child_frames_.end()) { *child = it->second; } else { - mutex_lock frame_lock(frame->mu); iter_state->outstanding_frame_count++; - outstanding_frames_[child_id] = temp; + frame->outstanding_child_frames_[child_id] = temp; *child = temp; temp = nullptr; } @@ -382,8 +423,10 @@ void PropagatorState::DeleteFrame(FrameState* frame, TaggedNodeSeq* ready) { // Delete the frame. if (vlog_) VLOG(2) << "Delete frame " << frame->frame_id; { - mutex_lock executor_lock(mu_); - outstanding_frames_.erase(frame->frame_id); + if (parent_frame != nullptr) { + mutex_lock parent_frame_lock(parent_frame->mu); + parent_frame->outstanding_child_frames_.erase(frame->frame_id); + } } delete frame; } @@ -551,6 +594,12 @@ int PropagatorState::FrameState::ActivateNodesSlowPathInternal( dst_ready = (adjust_result.pending_count == 1) && dst_dead; } } else { + if (dst_item->is_return) { + // In case of "Return" dst_node, + // we compare node's frame attr with current frame name + // if they are different, ignore this op + if (dst_item->call_id != call_id) continue; + } // Handle all other (non-merge) nodes. // We need to set the input of the op before adjusting activation. 
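(For what the call_id check above is matching against, an illustrative picture of one transformed recursive call site; the names follow the AddPrefixToNodeName scheme used by the grappler pass below, but the concrete names and ids are made up:)

    //   n ----> fib_out/Call_0      (Call,   frame_name="Fib", call_id=7)
    //               |
    //               v
    //           Fib/Input_0         (Identity at first; becomes a Merge once the
    //               |                recursive call site inside the body is
    //               v                transformed as well)
    //          ...inlined body...
    //               |
    //               +--> fib_out/Ret_0   (Return, call_id=7)  outer invocation
    //               +--> fib_rec/Ret_0   (Return, call_id=9)  recursive invocation
    //
    // The child frame created for call_id 7 should only activate the Return
    // tagged 7; the `dst_item->call_id != call_id` test above skips the Return
    // that belongs to the recursive invocation's frame.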
@@ -572,6 +621,23 @@ int PropagatorState::FrameState::ActivateNodesSlowPathInternal( increment_dead); dst_dead = adjust_result.dead_count > 0; dst_ready = !(adjust_result.pending_count > 0); + + if (dst_item->is_return && increment_dead) { + // The only dead input a Return op will ever may get + // is the control input propagated to it from a corresponding + // dead Call op in case of untaken branch. So at this point + // we are certain that Return op will never receive another input. + // Therefore, we force it to be added in queue for the sake of + // deadness propagation and we adjust it for activation once more, + // so that it no longer waits for another (never coming) input. + const PendingCounts::AdjustResult adjust_result = + atomic ? iter_state->adjust_for_activation_atomic(dst_pending_id, + increment_dead) + : iter_state->adjust_for_activation(dst_pending_id, + increment_dead); + dst_dead = adjust_result.dead_count > 0; + dst_ready = !(adjust_result.pending_count > 0); + } } maybe_add_to_ready(dst_id, dst_item, dst_ready, dst_dead); diff --git a/tensorflow/core/common_runtime/propagator_state.h b/tensorflow/core/common_runtime/propagator_state.h index 680cb13ef3ecb4..0b5d2329ac49df 100644 --- a/tensorflow/core/common_runtime/propagator_state.h +++ b/tensorflow/core/common_runtime/propagator_state.h @@ -261,6 +261,8 @@ class PropagatorState { // frame_name. uint64 frame_id; + int call_id = -1; + // The iteration state of its parent frame when this frame is created. // nullptr if there is no parent frame. The frame_name/parent_iter pair // uniquely identifies this FrameState. @@ -281,6 +283,15 @@ class PropagatorState { // The number of outstanding iterations. int num_outstanding_iterations TF_GUARDED_BY(mu) = 1; + // Mapping from frame ID to outstanding frames. A new frame is created + // at some iteration of an active frame. So the unique key for the new + // child frame is a hash composed of the ID of the parent frame, the iteration + // number at which the parent frame is creating the new frame, and the + // name of the new frame from nodedef. + absl::flat_hash_map outstanding_child_frames_ + TF_GUARDED_BY(mu); + + private: // The active iteration states of this frame. gtl::InlinedVector iterations; @@ -538,13 +549,6 @@ class PropagatorState { // The root frame in which the execution of this step is started. FrameState* root_frame_; - // Mapping from frame ID to outstanding frames. A new frame is created - // at some iteration of an active frame. So the unique key for the new - // child frame is a hash composed of the ID of the parent frame, the iteration - // number at which the parent frame is creating the new frame, and the - // name of the new frame from nodedef. 
- absl::flat_hash_map outstanding_frames_ - TF_GUARDED_BY(mu_); PropagatorState(const PropagatorState&) = delete; void operator=(const PropagatorState&) = delete; diff --git a/tensorflow/core/graph/graph.cc b/tensorflow/core/graph/graph.cc index 10984ae23608bc..6c84e6f2b291ce 100644 --- a/tensorflow/core/graph/graph.cc +++ b/tensorflow/core/graph/graph.cc @@ -61,6 +61,8 @@ Node::NodeClass Node::GetNodeClassForOp(const std::string& ts) { REF_CLASS("Enter", NC_ENTER), REF_CLASS("Exit", NC_EXIT), REF_CLASS("NextIteration", NC_NEXT_ITERATION), + REF_CLASS("Call", NC_CALL), + REF_CLASS("Return", NC_RETURN), {"LoopCond", NC_LOOP_COND}, {"ControlTrigger", NC_CONTROL_TRIGGER}, {"_Send", NC_SEND}, diff --git a/tensorflow/core/graph/graph.h b/tensorflow/core/graph/graph.h index c7a4f696bf126d..60927b1543fcab 100644 --- a/tensorflow/core/graph/graph.h +++ b/tensorflow/core/graph/graph.h @@ -167,6 +167,8 @@ class Node { bool IsEnter() const { return class_ == NC_ENTER; } bool IsExit() const { return class_ == NC_EXIT; } bool IsNextIteration() const { return class_ == NC_NEXT_ITERATION; } + bool IsCall() const { return class_ == NC_CALL; } + bool IsReturn() const { return class_ == NC_RETURN; } bool IsLoopCond() const { return class_ == NC_LOOP_COND; } bool IsControlTrigger() const { return class_ == NC_CONTROL_TRIGGER; } bool IsSend() const { return class_ == NC_SEND || class_ == NC_HOST_SEND; } @@ -182,6 +184,7 @@ class Node { bool IsControlFlow() const { return (class_ != NC_OTHER) && // Fast path (IsSwitch() || IsMerge() || IsEnter() || IsExit() || + IsCall() || IsReturn() || IsNextIteration()); } bool IsHostSend() const { return class_ == NC_HOST_SEND; } @@ -313,6 +316,8 @@ class Node { NC_ENTER, NC_EXIT, NC_NEXT_ITERATION, + NC_CALL, + NC_RETURN, NC_LOOP_COND, NC_CONTROL_TRIGGER, NC_SEND, @@ -935,6 +940,8 @@ inline bool IsMerge(const Node* node) { return node->IsMerge(); } inline bool IsEnter(const Node* node) { return node->IsEnter(); } inline bool IsExit(const Node* node) { return node->IsExit(); } inline bool IsNextIteration(const Node* n) { return n->IsNextIteration(); } +inline bool IsCall(const Node* node) { return node->IsCall(); } +inline bool IsReturn(const Node* node) { return node->IsReturn(); } inline bool IsLoopCond(const Node* node) { return node->IsLoopCond(); } inline bool IsControlTrigger(const Node* n) { return n->IsControlTrigger(); } inline bool IsSend(const Node* node) { return node->IsSend(); } diff --git a/tensorflow/core/grappler/function_transformation.h b/tensorflow/core/grappler/function_transformation.h new file mode 100644 index 00000000000000..2136ac5ccf7236 --- /dev/null +++ b/tensorflow/core/grappler/function_transformation.h @@ -0,0 +1,41 @@ +/* Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_GRAPPLER_OPTIMIZERS_FUNCTION_TRANSFORMATION_H_ +#define TENSORFLOW_GRAPPLER_OPTIMIZERS_FUNCTION_TRANSFORMATION_H_ + +#include "tensorflow/core/grappler/optimizers/graph_optimizer.h" +#include "tensorflow/core/grappler/grappler_item.h" + +namespace tensorflow { +namespace grappler { + + +// Replace function calling nodes with pairs of new 'Call/Return' operators +class FunctionTransformation : public GraphOptimizer { + public: + explicit FunctionTransformation() {} + ~FunctionTransformation() override = default; + + string name() const override { return "function_transformation"; }; + + bool UsesFunctionLibrary() const override { return true; } + + Status Optimize(Cluster* cluster, const GrapplerItem& item, + GraphDef* optimized_graph) override; +}; + +} // end namespace grappler +} // end namespace tensorflow + +#endif // TENSORFLOW_GRAPPLER_OPTIMIZERS_FUNCTION_TRANSFORMATION_H_ diff --git a/tensorflow/core/grappler/op_types.cc b/tensorflow/core/grappler/op_types.cc index e0981fe90c8ae9..46b04a24f41461 100644 --- a/tensorflow/core/grappler/op_types.cc +++ b/tensorflow/core/grappler/op_types.cc @@ -150,6 +150,11 @@ bool IsComplex(const NodeDef& node) { return node.op() == "Complex"; } bool IsComplexAbs(const NodeDef& node) { return node.op() == "ComplexAbs"; } +bool IsCall(const NodeDef& node) { + const auto& op = node.op(); + return op == "Call" || op == "RefCall"; +} + bool IsConcat(const NodeDef& node) { return node.op() == "Concat" || node.op() == "ConcatV2"; } @@ -498,6 +503,11 @@ bool IsRetval(const NodeDef& node) { return node.op() == "_Retval" || node.op() == "_DeviceRetval"; } +bool IsReturn(const NodeDef& node) { + const auto& op = node.op(); + return op == "Return" || op == "RefReturn"; +} + bool IsReverse(const NodeDef& node) { return node.op() == "Reverse" || node.op() == "ReverseV2"; } @@ -777,7 +787,7 @@ bool ModifiesInputsInPlace(const NodeDef& node) { } bool ModifiesFrameInfo(const NodeDef& node) { - return IsEnter(node) || IsExit(node) || IsNextIteration(node); + return IsEnter(node) || IsExit(node) || IsNextIteration(node) || IsCall(node) || IsReturn(node); } #define OPDEF_PROPERTY_HELPER(PROPERTY_CAP, PROPERTY) \ diff --git a/tensorflow/core/grappler/op_types.h b/tensorflow/core/grappler/op_types.h index c233b6e9c6b61a..9546e948136177 100644 --- a/tensorflow/core/grappler/op_types.h +++ b/tensorflow/core/grappler/op_types.h @@ -55,6 +55,7 @@ bool IsCheckNumerics(const NodeDef& node); bool IsCollective(const NodeDef& node); bool IsComplex(const NodeDef& node); bool IsComplexAbs(const NodeDef& node); +bool IsCall(const NodeDef& node); bool IsConcat(const NodeDef& node); bool IsConcatOffset(const NodeDef& node); bool IsConj(const NodeDef& node); @@ -163,6 +164,7 @@ bool IsRsqrt(const NodeDef& node); bool IsRsqrtGrad(const NodeDef& node); bool IsSelect(const NodeDef& node); bool IsSeluGrad(const NodeDef& node); +bool IsReturn(const NodeDef& node); bool IsSend(const NodeDef& node); bool IsShape(const NodeDef& node); bool IsShapeN(const NodeDef& node); diff --git a/tensorflow/core/grappler/optimizers/BUILD b/tensorflow/core/grappler/optimizers/BUILD index e967c46836756d..1881b2496b9913 100644 --- a/tensorflow/core/grappler/optimizers/BUILD +++ b/tensorflow/core/grappler/optimizers/BUILD @@ -445,6 +445,47 @@ tf_cuda_cc_test( ], ) +cc_library( + name = "function_transformation", + srcs = ["function_transformation.cc"], + hdrs = [ + "function_transformation.h", + ], + 
visibility = ["//visibility:public"], + deps = [ + ":graph_optimizer", + ":function_optimizer", + "//tensorflow/core:core_cpu_base", + "//tensorflow/core:framework", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core/grappler:grappler_item", + "//tensorflow/core/grappler:op_types", + "//tensorflow/core/grappler:utils", + "//tensorflow/core/grappler/utils:functions", + ], +) + +tf_cc_test( + name = "function_transformation_test", + srcs = ["function_transformation_test.cc"], + shard_count = 5, + deps = [ + ":function_transformation", + "//tensorflow/cc:cc_ops", + "//tensorflow/core:all_kernels", + "//tensorflow/core:core_cpu", + "//tensorflow/core:core_cpu_internal", + "//tensorflow/core:direct_session", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + "//tensorflow/core/grappler:grappler_item", + "//tensorflow/core/grappler:utils", + ], +) + cc_library( name = "model_pruner", srcs = ["model_pruner.cc"], diff --git a/tensorflow/core/grappler/optimizers/function_transformation.cc b/tensorflow/core/grappler/optimizers/function_transformation.cc new file mode 100644 index 00000000000000..7963083e7a384f --- /dev/null +++ b/tensorflow/core/grappler/optimizers/function_transformation.cc @@ -0,0 +1,650 @@ +/* Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/core/grappler/optimizers/function_transformation.h" +#include +#include +#include +#include "tensorflow/core/util/event.pb.h" +#include "tensorflow/core/util/events_writer.h" + +#include "tensorflow/core/common_runtime/graph_constructor.h" +#include "tensorflow/core/graph/tensor_id.h" +#include "tensorflow/core/grappler/grappler_item.h" +#include "tensorflow/core/grappler/op_types.h" +#include "tensorflow/core/grappler/utils.h" +#include "tensorflow/core/grappler/utils/functions.h" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/function.pb.h" +#include "tensorflow/core/framework/graph_def_util.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/framework/op_def.pb.h" +#include "tensorflow/core/framework/versions.pb.h" + +namespace tensorflow { +namespace grappler { +namespace { + +static constexpr const char* const kCallOp = "Call"; +static constexpr const char* const kRetOp = "Return"; +static constexpr const char* const kIdentityOp = "Identity"; +static constexpr const char* const kIdentityNOp = "IdentityN"; +static constexpr const char* const kMergeOp = "Merge"; + +struct FuncInfo { + gtl::ArraySlice fetch; + std::vector inputs; + std::vector input_def; + std::vector outputs; + std::vector output_def; +}; + +// same with commit b691c0 (possibly) +class FunctionInliningContext { + public: + explicit FunctionInliningContext(const GrapplerItem& item) + : item_(&item), function_library_(OpRegistry::Global(), item.graph.library()), functions_(InliningCandidates(item)) {} + + const FunctionLibraryDefinition& Library() const { return function_library_; } + + bool HasInlinedFunctions() const { return !functions_.empty(); } + + // Find inlining candidate by name. Return nullptr if not found. + const FunctionDef* FindInlinedFunction(const string& name) const { + auto it = functions_.find(name); + if (it != functions_.end()) { + return it->second; + } else { + return nullptr; + } + } + + const int graph_version() const { + return item_->graph.versions().producer(); + } + + private: + std::unordered_map InliningCandidates(const GrapplerItem& item) const { + std::unordered_map functions; + for (const FunctionDef& func : item.graph.library().function()) { + // Don't inline functions marked as noinline + // if (func.attr().count("_noinline") != 0) { + // continue; + // } + // Don't touch anything marked XLA to prevent XLA failures further down + // the road. + if (func.attr().count("_XlaCompile") > 0 && + func.attr().at("_XlaCompile").b()) { + continue; + } + // Can't create IdentityN nodes with no input or output: skip these + // functions for now. + if (func.signature().input_arg_size() == 0 || + func.signature().output_arg_size() == 0) { + continue; + } + functions[func.signature().name()] = &func; + } + return functions; + } + + FunctionLibraryDefinition function_library_; + std::unordered_map functions_; + const GrapplerItem* item_; + TF_DISALLOW_COPY_AND_ASSIGN(FunctionInliningContext); +}; + + +constexpr const char* const kFuncAttr = FunctionLibraryDefinition::kFuncAttr; + +// There are two ways of calling a Tensorflow function: +// +// 1. Direct function call: node.op() is the name of the function. +// +// 2. 
Indirect function call: the function name is passed through a node +// attribute, and special Tensorflow kernels are responsible for calling the +// function through the FunctionLibraryRuntime. Example: PartitionedCallOp. + +// Check if func_node.op() matches the name in FunctionDef signature. +bool IsDirectFunctionCall(const FunctionDef& func, const NodeDef& func_node) { + return func_node.op() == func.signature().name(); +} + +// Check if func_node has function attribute with a function name matching +// FunctionDef signature. +bool IsIndirectFunctionCall(const FunctionDef& func, const NodeDef& func_node) { + if (!IsPartitionedCall(func_node) && !IsStatefulPartitionedCall(func_node)) { + return false; + } + + auto* func_attr = AttrSlice(func_node).Find(kFuncAttr); + return func_attr != nullptr && func_attr->has_func() && + func_attr->func().name() == func.signature().name(); +} + +AttrSlice FunctionInstantiationAttributes(const FunctionDef& func, + const NodeDef& func_node) { + if (IsDirectFunctionCall(func, func_node)) { + return AttrSlice(func_node); + + } else if (IsIndirectFunctionCall(func, func_node)) { + auto* func_attr = AttrSlice(func_node).Find(kFuncAttr); + return AttrSlice(&func_attr->func().attr()); + + } else { + LOG(WARNING) << "Can't resolve function instantiation attributes: " + << SummarizeNodeDef(func_node); + return AttrSlice(); + } +} + +// Copy input/output argument type to the type_list. Return error if argument +// type is not explicitly defined, and not specified in function attributes. +Status CopyArgType(const OpDef::ArgDef& arg, + const AttrSlice& func_attr, + DataType* type) { + if (arg.type() != DT_INVALID) { + *type = arg.type(); + } else { + const AttrValue* it = func_attr.Find(arg.type_attr()); + if (it == nullptr || it->type() == DT_INVALID) { + return errors::InvalidArgument( + "Invalid argument ", arg.name()); + } + *type = it->type(); + } + return Status::OK(); +} + +// Copy input/output argument type to the type_list. Return error if argument +// type is not explicitly defined, and not specified in function attributes. 
+Status CopyArgType(const OpDef::ArgDef& arg, + const AttrSlice& func_attr, + AttrValue::ListValue* type_list) { + if (arg.type() != DT_INVALID) { + type_list->add_type(arg.type()); + } else { + const AttrValue* it = func_attr.Find(arg.type_attr()); + if (it == nullptr || it->type() == DT_INVALID) { + return errors::InvalidArgument("Invalid argument ", arg.name()); + } + type_list->add_type(it->type()); + } + return Status::OK(); +} + +struct CallInfo { + int call_id; + NodeDef* node; + string node_name; + string function_name; + string device; + std::vector input_nodes; + AttrSlice attr; +}; + +class CallRewriter { + + public: + explicit CallRewriter(const GrapplerItem item_, GraphDef* graph_, const FunctionInliningContext& ctx_) + : graph(graph_), ctx(ctx_), item(item_) { } + + ~CallRewriter() { + Flush(); + } + + Status CollectCalls(std::vector& calls); + + Status TransformCall(CallInfo& call_info); + + // Inlines a function to item.graph and if already inlined provide func_info + Status FindCompatibleOrInlineFunction(const string& name, + const AttrSlice& func_attr, + const string& device, + GraphDef* optimized_graph, FuncInfo& func_info); + + void Flush() { + if (!nodes_to_delete.empty()) { + // garbage collect the transformed call nodes + int last = graph->node_size() - 1; + for (int i = graph->node_size() - 1; i >= 0; --i) { + const NodeDef& node = graph->node(i); + if (nodes_to_delete.find(node.name()) != nodes_to_delete.end()) { + graph->mutable_node()->SwapElements(i,last); + last--; + } + } + + graph->mutable_node()->DeleteSubrange(last + 1, + graph->node_size() - last - 1); + + nodes_to_delete.clear(); + } + + if (!output_map_.empty()) { + // change all the recorded outputs; + // the new outputs where produced by the addition of the RetOp and + // the substitution was deferred to increase performance + for (NodeDef& node : *graph->mutable_node()) { + for (string& in : *node.mutable_input()) { + auto it = output_map_.find(in); + if (it != output_map_.end()) { + in = it->second; + } + } + } + output_map_.clear(); + } + } + + inline int GetCallId(const NodeDef& node) { int call_id = id; id++; return call_id; } + + private: + Status AddCallOp(const CallInfo& call_info, const OpDef::ArgDef arg, + const string& input, int arg_id, NodeDef* call_node); + + Status AddRetOp(const CallInfo& call_info, const OpDef::ArgDef arg, + const string& input, int arg_id, NodeDef* ret_node); + + Status ConnectInput(NodeDef* from, NodeDef* to); + + bool ShouldPreserveOutputs(const string& node) { + for (const string& fetch_out : item.fetch) { + if (NodeName(fetch_out) == node) + return true; + } + return false; + } + + void ReplaceOutput(const string& old_output, const string& new_output) { + // maybe some more checks + output_map_[old_output] = new_output; + } + + void MarkCallTransformed(CallInfo& call_info) { + NodeDef* node = call_info.node; + node->clear_input(); + node->set_op("NoOp"); + node->set_name(AddPrefixToNodeName(node->name(), "$MarkToDelete$")); + nodes_to_delete.insert(node->name()); + } + + GraphDef* graph; + const FunctionInliningContext& ctx; + const GrapplerItem item; + std::unordered_map transformed_functions_; + std::unordered_map output_map_; + std::set nodes_to_delete; + int id = 0; + + TF_DISALLOW_COPY_AND_ASSIGN(CallRewriter); +}; + + +Status CallRewriter::CollectCalls(std::vector& calls) { + + // identify and collect calls in the graph + for (NodeDef& node : *graph->mutable_node()) { + const FunctionDef* func = ctx.FindInlinedFunction(node.op()); + if (func != nullptr) { + 
CallInfo call; + call.call_id = GetCallId(node); + call.node_name = node.name(); + call.function_name = node.op(); + call.node = &node; + call.device = node.device(); + call.attr = FunctionInstantiationAttributes(*func, node); + + int input_size = func->signature().input_arg_size(); + call.input_nodes.resize(input_size); + for (int i = 0; i < input_size; i++) { + call.input_nodes[i] = node.input(i); + } + calls.push_back(call); + } + } + return Status::OK(); +} + +Status CallRewriter::AddCallOp(const CallInfo& call_info, + const OpDef::ArgDef arg, + const string& input, + int arg_id, NodeDef* call) { + string prefix = call_info.node_name; + string call_name = strings::StrCat("Call", "_", arg_id); + call->set_op(kCallOp); + call->set_name(AddPrefixToNodeName(call_name, prefix)); + //call->set_device(node.device()); + call->add_input(input); + + DataType type; + TF_RETURN_IF_ERROR(CopyArgType(arg, call_info.attr, &type)); + + auto& attr = *call->mutable_attr(); + + //SetArgType(arg, call_info.attr, attr); + + attr["T"].set_type(type); + attr["frame_name"].set_s(call_info.function_name); + attr["call_id"].set_i(call_info.call_id); + attr["arg_id"].set_i(arg_id); + attr["is_constant"].set_b(false); + + return Status::OK(); +} + +Status CallRewriter::AddRetOp(const CallInfo& call_info, + const OpDef::ArgDef arg, + const string& input, + int arg_id, NodeDef* ret) { + string prefix = call_info.node_name; + string ret_name = strings::StrCat("Ret", "_", arg_id); + ret->set_op(kRetOp); + ret->set_name(AddPrefixToNodeName(ret_name, prefix)); + ret->add_input(input); + + DataType type; + TF_RETURN_IF_ERROR(CopyArgType(arg, call_info.attr, &type)); + + auto& attr = *ret->mutable_attr(); + attr["T"].set_type(type); + attr["frame_name"].set_s(call_info.function_name); + attr["call_id"].set_i(call_info.call_id); + attr["arg_id"].set_i(arg_id); + + return Status::OK(); +} + +Status CallRewriter::ConnectInput(NodeDef* from, NodeDef* to) { + int to_input = to->input_size(); + if (to_input == 1) { + // it is Identity and we convert it to Merge. + CHECK(IsIdentity(*to)); + to->set_op(kMergeOp); + } + to->add_input(from->name()); + if (to->input_size() > 1) { + (*to->mutable_attr())["N"].set_i(to->input_size()); + } + return Status::OK(); +} + +Status CallRewriter::TransformCall(CallInfo& call_info) { + FuncInfo func_info; + + // inlines the body of a function and provides a struct with func_info + TF_RETURN_IF_ERROR(FindCompatibleOrInlineFunction( + call_info.function_name, call_info.attr, call_info.device, graph, func_info)); + + CHECK_EQ(call_info.input_nodes.size(), func_info.inputs.size()); + + std::vector call_nodes; + std::vector ret_nodes; + + call_nodes.resize(func_info.inputs.size()); + for (unsigned int arg_num = 0; arg_num < func_info.inputs.size(); arg_num++) { + call_nodes[arg_num] = graph->add_node(); + AddCallOp(call_info, + func_info.input_def[arg_num], + call_info.input_nodes[arg_num], + arg_num, + call_nodes[arg_num]); + + call_nodes[arg_num]->set_device(call_info.device); + + // connect the input of the inlined function to feed from call. 
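+        // (Illustrative note: on the first transformed call site, ConnectInput
+        // leaves the function's "Input_i" node as an Identity fed by this Call;
+        // when a later call site of the same inlined function is transformed,
+        // the same node is switched to a Merge and its "N" attr is bumped to
+        // the new input count.)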
+ TF_RETURN_IF_ERROR(ConnectInput(call_nodes[arg_num], func_info.inputs[arg_num])); + } + + ret_nodes.resize(func_info.outputs.size()); + for (unsigned int out_port = 0; out_port < func_info.outputs.size(); out_port++) { + ret_nodes[out_port] = graph->add_node(); + AddRetOp(call_info, + func_info.output_def[out_port], + func_info.outputs[out_port], + out_port, + ret_nodes[out_port]); + ret_nodes[out_port]->set_device(call_info.device); + } + + // for each call create a control dependency to each return + // to facilitate dead propagation semantics + for (NodeDef* ret : ret_nodes) { + for (NodeDef* call : call_nodes) + *(ret->add_input()) = AsControlDependency(call->name()); + } + + if (ShouldPreserveOutputs(call_info.node_name)) { + // create an IdentityN with the same name of the initial function call + // so as to preserve the naming of the outputs. + // we re-use the initial node and we change (a) the op to IdentityN and + // (b) the inputs to point to the outputs of the ret_nodes + // The other information such as types, device placement etc remain the same. + // The IdentityN node will sync the outputs and therefore may result to performance degradation. + NodeDef* out = graph->add_node(); + out->set_op(kIdentityNOp); + out->set_name(call_info.node_name); + out->set_device(call_info.device); + AttrValue::ListValue* type_list = (*out->mutable_attr())["T"].mutable_list(); + for (const OpDef::ArgDef& arg : func_info.output_def) { + TF_RETURN_IF_ERROR(CopyArgType(arg, call_info.attr, type_list)); + } + for (unsigned int i = 0; i < func_info.outputs.size(); i++) { + *out->add_input() = ret_nodes[i]->name(); + } + } else { + for (unsigned int out_port = 0; out_port < func_info.outputs.size(); out_port++) { + ReplaceOutput(strings::StrCat(call_info.node_name, ":", out_port), ret_nodes[out_port]->name()); + } + if (func_info.outputs.size() == 1) { + ReplaceOutput(call_info.node_name, ret_nodes[0]->name()); + } + } + printf("Mark call %s (function %s) as transformed\n", call_info.node_name.c_str(), call_info.function_name.c_str()); + MarkCallTransformed(call_info); + + return Status::OK(); +} + +Status InlineFunction(const FunctionDef& func_def, + const FunctionInliningContext& ctx, + const AttrSlice& func_attr, + const string& device, + GraphDef* graph, FuncInfo& func_info) { + // std::unique_ptr item = GrapplerItemFromFunctionDef(func_def, func_attr, ctx.Library()); + GrapplerFunctionItem fitem; + GrapplerFunctionItem* item = &fitem; + TF_RETURN_IF_ERROR(MakeGrapplerFunctionItem( + func_def, func_attr, ctx.Library(), ctx.graph_version(), item)); + + string prefix = func_def.signature().name(); + + if (!item) { + return errors::InvalidArgument( + "Failed to inline function ", func_def.signature().name()); + } + int arg_size = func_def.signature().input_arg_size(); + // create an inverse map of arg to provide name -> argument number + std::unordered_map input_nodes; + for (int i = 0; i < arg_size; ++i) { + const OpDef::ArgDef& arg = func_def.signature().input_arg(i); + input_nodes[arg.name()] = i; + } + func_info.inputs.resize(arg_size); + func_info.input_def.resize(arg_size); + for (int i = 0; i < arg_size; ++i) { + const OpDef::ArgDef& arg = func_def.signature().input_arg(i); + NodeDef* merge = graph->add_node(); + merge->set_name(AddPrefixToNodeName(strings::StrCat("Input", "_", i), prefix)); + merge->set_op(kIdentityOp); + merge->set_device(device); + + DataType type; + TF_RETURN_IF_ERROR(CopyArgType(arg, func_attr, &type)); + auto& attr = *merge->mutable_attr(); + 
attr["T"].set_type(type); + + func_info.inputs[i] = merge; + func_info.input_def[i] = arg; + } + + // prefix each node in function graph and place it to the global graph. + // the inputs of each node need to be renamed as well to reflect the change. + for (NodeDef& func_body_node : *item->graph.mutable_node()) { + const string& curr_name = func_body_node.name(); + // If the func body node is func's input argument + auto input_it = input_nodes.find(curr_name); + + if (input_it != input_nodes.end()) { + CHECK_EQ(0, func_body_node.input_size()); + // Turn input placeholders into identity nodes + if (IsPlaceholder(func_body_node)) { + func_body_node.set_op(kIdentityOp); + } + // Connect merge with input arg + func_body_node.add_input(func_info.inputs[input_it->second]->name()); + } else { + // Else if not an input_arg_node + // Update the input names if any. + for (string& input : *func_body_node.mutable_input()) { + input = AddPrefixToNodeName(input, prefix); + } + // If the node has no input, make hook it up to the Merge nodes to ensure + // it runs in the same frame as the other nodes of the function body. + if (func_body_node.input_size() == 0) { + for (auto& func_input_node : func_info.inputs) { + *func_body_node.add_input() = AsControlDependency(func_input_node->name()); + } + } + } + + // Add the node name as a prefix to avoid collisions after inlining + func_body_node.set_name(AddPrefixToNodeName(curr_name, prefix)); + + // Make sure the node is placed + if (func_body_node.device().empty()) + func_body_node.set_device(device); + + // Move the node to the main graph + graph->add_node()->Swap(&func_body_node); + } + + func_info.outputs.clear(); + func_info.outputs.resize(item->fetch.size()); + func_info.output_def.resize(item->fetch.size()); + + for (unsigned int i = 0; i < item->fetch.size(); i++) { + func_info.outputs[i] = AddPrefixToNodeName(item->fetch[i], prefix); + func_info.output_def[i] = func_def.signature().output_arg(i); + } + + return Status::OK(); +} + +// new +Status CallRewriter::FindCompatibleOrInlineFunction( + const string& func_name, + const AttrSlice& func_attr, + const string& device, + GraphDef* graph, + FuncInfo& func_info) { + const auto& it = transformed_functions_.find(func_name); + // maybe it is not wise to discard call attributes + // possible type specialization? 
+ if (it != transformed_functions_.end()) { + func_info = it->second; + return Status::OK(); + } + const FunctionDef* func_def = ctx.FindInlinedFunction(func_name); + if (func_def == nullptr) { + return errors::InvalidArgument( + "Invalid argument, function ", func_name, "can not be found", + "or not marked to be inlined"); + } + TF_RETURN_IF_ERROR( + InlineFunction(*func_def, ctx, func_attr, device, graph, func_info)); + transformed_functions_[func_name] = func_info; + printf("Store inlined function %s\n", func_name.c_str()); + return Status::OK(); +} + +} // namespace + +Status FunctionTransformation::Optimize(Cluster* cluster, const GrapplerItem& item, + GraphDef* output) { + FunctionInliningContext ctx(item); + CallRewriter call_rewriter(item, output, ctx); + + *output = item.graph; + if (!ctx.HasInlinedFunctions()) { + return Status::OK(); + } + + std::vector calls; + while (1) { + TF_RETURN_IF_ERROR(call_rewriter.CollectCalls(calls)); + if (calls.empty()) { + break; + } + for (CallInfo& call : calls) { + Status s = call_rewriter.TransformCall(call); + if (!s.ok()) { + printf("Error: %s\n", s.error_message().c_str()); + return s; + } + printf("After transforming call %s:\n %s\n", call.function_name.c_str(), SummarizeGraphDef(*output).c_str()); + } + calls.clear(); + call_rewriter.Flush(); + } + call_rewriter.Flush(); + printf("After finalizing:\n %s\n", SummarizeGraphDef(*output).c_str()); + *output->mutable_versions() = item.graph.versions(); + + // Function Library should be pruned of unreachable function definitions + // cf. https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/grappler/optimizers/function_optimizer.cc#L428 + // however in this version there is a check in meta_optimizer that guarantees + // that function library remains of the same length + // cf. https://github.com/acharal/tensorflow/blob/r1.4_recursion/tensorflow/core/grappler/optimizers/meta_optimizer.cc#L132 + *output->mutable_library() = item.graph.library(); + + + + /******************************************************************************************************/ + // Dumps optimized graph in a not so readable form + // const GraphDef* tmp = optimized_graph; + // printf("Summarize Optimized Graph\n %s\n", SummarizeGraphDef(*tmp).c_str()); + // Write an event, so that we can visualize this optimized graph in tensorboard + EventsWriter writer("TRANSFORMATION"); + Event event; + event.set_wall_time(1234); + event.set_step(34); + const size_t proto_size = output->ByteSizeLong(); + void* buf = port::Malloc(proto_size); + if (buf == nullptr) { + return errors::ResourceExhausted( + "Failed to allocate memory to serialize message of type '" , + output->GetTypeName(), "' and size ", proto_size); + } + output->SerializeToArray(buf, proto_size); + const void* bf = buf; + event.set_graph_def(bf, proto_size); + writer.WriteEvent(event); + /******************************************************************************************************/ + + return Status::OK(); +} + +} // end namespace grappler +} // end namespace tensorflow diff --git a/tensorflow/core/grappler/optimizers/function_transformation_test.cc b/tensorflow/core/grappler/optimizers/function_transformation_test.cc new file mode 100644 index 00000000000000..751278545984e2 --- /dev/null +++ b/tensorflow/core/grappler/optimizers/function_transformation_test.cc @@ -0,0 +1,57 @@ +/* + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/grappler/optimizers/function_transformation.h" +#include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/tensor_testutil.h" +#include "tensorflow/core/grappler/grappler_item.h" +#include "tensorflow/core/grappler/utils.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/public/session.h" + +namespace tensorflow { +namespace grappler { +namespace { + + +class FunctionTransformationTest : public ::testing::Test { + +}; + +TEST_F(FunctionTransformationTest, NoTrans) { + + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + + Output a = ops::Const(s.WithOpName("a"), 1.0f, {1}); + Output b = ops::Const(s.WithOpName("b"), 2.0f, {1}); + Output c = ops::AddN(s.WithOpName("c").WithDevice("/CPU:0"), {a, b}); + Output d = ops::AddN(s.WithOpName("d"), {b, c}); + + GrapplerItem item; + item.fetch.push_back("d"); + TF_CHECK_OK(s.ToGraphDef(&item.graph)); + + FunctionTransformation func_trans; + GraphDef output; + Status status = func_trans.Optimize(nullptr, item, &output); + TF_EXPECT_OK(status); +} + +} // namespace +} // namespace grappler +} // namespace tensorflow diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer.cc b/tensorflow/core/grappler/optimizers/meta_optimizer.cc index 84afab6e12badf..1da252ce7a51b2 100644 --- a/tensorflow/core/grappler/optimizers/meta_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/meta_optimizer.cc @@ -45,6 +45,7 @@ limitations under the License. 
#include "tensorflow/core/grappler/optimizers/debug_stripper.h" #include "tensorflow/core/grappler/optimizers/dependency_optimizer.h" #include "tensorflow/core/grappler/optimizers/function_optimizer.h" +#include "tensorflow/core/grappler/optimizers/function_transformation.h" #include "tensorflow/core/grappler/optimizers/generic_layout_optimizer.h" #include "tensorflow/core/grappler/optimizers/implementation_selector.h" #include "tensorflow/core/grappler/optimizers/loop_optimizer.h" @@ -261,6 +262,10 @@ std::unique_ptr MetaOptimizer::MakeNewOptimizer( MK_OPT("pin_to_host", "pin_to_host_optimization", new PinToHostOptimizer(cfg_.pin_to_host_optimization())); + if (LowerControlFlow()) { + MK_OPT("function_transformation", "function_transformation", new FunctionTransformation()); + } + return std::unique_ptr(); } @@ -486,7 +491,14 @@ Status MetaOptimizer::InitializeOptimizers( optimizers->push_back( std::make_unique(cfg_.auto_parallel().num_replicas())); } - + if (BOTH_NOT_OFF(function_transformation)) { + if (USER_IS_EXPERIMENTAL_MLIR(function_transformation) || + USER_IS_EXPERIMENTAL_BOTH(function_transformation)) { + VLOG(2) << "function_transformation is not implemented in TFG yet"; + } else { + optimizers->push_back(MakeUnique()); + } + } #ifndef ENABLE_MKL if (BOTH_ARE_ON(scoped_allocator_optimization)) { optimizers->push_back(std::make_unique( @@ -641,6 +653,7 @@ void MetaOptimizer::PrintUserAndPluginConfigs( PRINT_CFG(loop_optimization) PRINT_CFG(dependency_optimization) PRINT_CFG(scoped_allocator_optimization) + PRINT_CFG(function_transformation) #undef PRINT_CFG user_cfg.toggle_config["auto_mixed_precision"] = AutoMixedPrecisionEnabled(cfg_.auto_mixed_precision()) @@ -696,6 +709,7 @@ void MetaOptimizer::PrintUserAndPluginConfigs( PRINT_CFG("memory", "memory_optimization") PRINT_CFG("autoparallel", "auto_parallel") PRINT_CFG("scoped_allocator", "scoped_allocator_optimization") + PRINT_CFG("function_transformation", "function_transformation") #undef PRINT_CFG } } @@ -1353,6 +1367,7 @@ bool MetaOptimizerEnabled(const ConfigProto& cfg) { rewrite_cfg.auto_parallel().enable() || rewrite_cfg.memory_optimization() != RewriterConfig::NO_MEM_OPT || rewrite_cfg.debug_stripper() == RewriterConfig::ON || + rewrite_cfg.function_transformation() != RewriterConfig::OFF || #ifndef ENABLE_MKL rewrite_cfg.scoped_allocator_optimization() == RewriterConfig::ON || #endif diff --git a/tensorflow/core/grappler/utils/topological_sort.cc b/tensorflow/core/grappler/utils/topological_sort.cc index 29e00240028715..5f4be7975449ce 100644 --- a/tensorflow/core/grappler/utils/topological_sort.cc +++ b/tensorflow/core/grappler/utils/topological_sort.cc @@ -56,6 +56,7 @@ Status ComputeTopologicalOrder( // Keep track of how many inputs are ready for the given node. std::vector num_ready_inputs(graph.node_size(), 0); + std::unordered_map> returning_nodes; // We'll push index of ready nodes to this output vector. 
ready_nodes->reserve(graph.node_size()); @@ -68,12 +69,47 @@ Status ComputeTopologicalOrder( ready_nodes->push_back(i); back++; } + bool recursion_merge = false; if (IsMerge(graph.node(i))) { for (int input : graph_view.GetFanin(i)) { if (IsNextIteration(graph.node(input))) { num_ready_inputs[i]++; } + else if (IsCall(graph.node(input))) { + num_ready_inputs[i]++; + recursion_merge = true; + } + } + if (recursion_merge) { + num_ready_inputs[i]--; + recursion_merge = false; } + } else if (IsReturn(graph.node(i))) { + // Nodes that send their output to "Return" nodes are + // function Returning Nodes and in case of recursive functions + // those nodes are part of graph cycles. + for (int input : graph_view.GetFanin(i)) { + // In order to detect the recursion cycles we depend on + // the fact that a recursive function's returning node, + // will be sending outputs to at least 2 "Return" nodes + // with different "call_id" attributes (same "call_id" + // attrs would mean that they belong in the same function call + // but they correspond to different function outputs) + // if (!StringPiece(graph.node(input)).starts_with("^")) { + if (true) { + int call_id; + GetNodeAttr(graph.node(i), "call_id", &call_id); + returning_nodes[input].emplace(call_id); + } + } + num_ready_inputs[i] = 0; + } + } + + for (const auto& retnode : returning_nodes) { + if (retnode.second.size() > 1) { + // Detected Cycle + num_ready_inputs[retnode.first]++; } } diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD index c541597c8d9d80..ffd85a9c0766f8 100644 --- a/tensorflow/core/kernels/BUILD +++ b/tensorflow/core/kernels/BUILD @@ -2528,6 +2528,15 @@ tf_cc_test( ], ) +tf_kernel_library( + name = "function_control_ops", + prefix = "function_control_ops", + deps = [ + "//tensorflow/core:framework", + "//tensorflow/core:lib", + ], +) + cc_library( name = "data_flow", deps = [ diff --git a/tensorflow/core/kernels/function_control_ops.cc b/tensorflow/core/kernels/function_control_ops.cc new file mode 100644 index 00000000000000..a22c0791022613 --- /dev/null +++ b/tensorflow/core/kernels/function_control_ops.cc @@ -0,0 +1,191 @@ +/* Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include "tensorflow/core/kernels/function_control_ops.h" + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/platform/macros.h" + +namespace tensorflow { + +void CallOpe::Compute(OpKernelContext* context) { + if (IsRefType(context->input_dtype(0))) { + context->forward_ref_input_to_ref_output(0, 0); + } else { + context->set_output(0, context->input(0)); + } +} + +REGISTER_KERNEL_BUILDER(Name("Call").Device(DEVICE_CPU), CallOpe); +REGISTER_KERNEL_BUILDER(Name("RefCall").Device(DEVICE_CPU), CallOpe); + +#define REGISTER_GPU_KERNEL(type) \ + REGISTER_KERNEL_BUILDER( \ + Name("Call").Device(DEVICE_GPU).TypeConstraint("T"), CallOpe) +#define REGISTER_GPU_REF_KERNEL(type) \ + REGISTER_KERNEL_BUILDER( \ + Name("RefCall").Device(DEVICE_GPU).TypeConstraint("T"), CallOpe) + +TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_KERNEL); +TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_REF_KERNEL); +REGISTER_GPU_KERNEL(bool); +REGISTER_GPU_REF_KERNEL(bool); + +#undef REGISTER_GPU_KERNEL +#undef REGISTER_GPU_REF_KERNEL + +#ifdef TENSORFLOW_USE_SYCL +#define REGISTER_SYCL_KERNEL(type) \ + REGISTER_KERNEL_BUILDER( \ + Name("Call").Device(DEVICE_SYCL).TypeConstraint("T"), CallOpe) +REGISTER_SYCL_KERNEL(bool); +TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_SYCL_KERNEL); + +#define REGISTER_SYCL_REF_KERNEL(type) \ + REGISTER_KERNEL_BUILDER( \ + Name("RefCall").Device(DEVICE_SYCL).TypeConstraint("T"), CallOpe) +REGISTER_SYCL_REF_KERNEL(bool); +TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_SYCL_REF_KERNEL); + +#undef REGISTER_SYCL_KERNEL +#undef REGISTER_SYCL_REF_KERNEL +#define REGISTER_SYCL_HOST_KERNEL(type) \ + REGISTER_KERNEL_BUILDER(Name("Call") \ + .Device(DEVICE_SYCL) \ + .HostMemory("data") \ + .HostMemory("output") \ + .TypeConstraint("T"), \ + CallOpe) + +#define REGISTER_SYCL_HOST_REF_KERNEL(type) \ + REGISTER_KERNEL_BUILDER(Name("RefCall") \ + .Device(DEVICE_SYCL) \ + .HostMemory("data") \ + .HostMemory("output") \ + .TypeConstraint("T"), \ + CallOpe) + +REGISTER_SYCL_HOST_KERNEL(int32); +REGISTER_SYCL_HOST_REF_KERNEL(int32); +REGISTER_SYCL_HOST_KERNEL(string); +REGISTER_SYCL_HOST_REF_KERNEL(string); +REGISTER_SYCL_HOST_KERNEL(ResourceHandle); + +#undef REGISTER_SYCL_HOST_KERNEL +#undef REGISTER_SYCL_HOST_REF_KERNEL +#endif // TENSORFLOW_USE_SYCL + +#define REGISTER_GPU_HOST_KERNEL(type) \ + REGISTER_KERNEL_BUILDER(Name("Call") \ + .Device(DEVICE_GPU) \ + .HostMemory("data") \ + .HostMemory("output") \ + .TypeConstraint("T"), \ + CallOpe) + +#define REGISTER_GPU_HOST_REF_KERNEL(type) \ + REGISTER_KERNEL_BUILDER(Name("RefCall") \ + .Device(DEVICE_GPU) \ + .HostMemory("data") \ + .HostMemory("output") \ + .TypeConstraint("T"), \ + CallOpe) + +REGISTER_GPU_HOST_KERNEL(int32); +REGISTER_GPU_HOST_REF_KERNEL(int32); +REGISTER_GPU_HOST_KERNEL(string); +REGISTER_GPU_HOST_REF_KERNEL(string); +REGISTER_GPU_HOST_KERNEL(ResourceHandle); + +#undef REGISTER_GPU_HOST_KERNEL +#undef REGISTER_GPU_HOST_REF_KERNEL + +void ReturnOp::Compute(OpKernelContext* context) { + if (IsRefType(context->input_dtype(0))) { + context->forward_ref_input_to_ref_output(0, 0); + } else { + context->set_output(0, context->input(0)); + } +} + +REGISTER_KERNEL_BUILDER(Name("Return").Device(DEVICE_CPU), ReturnOp); +REGISTER_KERNEL_BUILDER(Name("RefReturn").Device(DEVICE_CPU), ReturnOp); + +#define 
REGISTER_GPU_KERNEL(type) \ + REGISTER_KERNEL_BUILDER( \ + Name("Return").Device(DEVICE_GPU).TypeConstraint("T"), ReturnOp); +#define REGISTER_GPU_REF_KERNEL(type) \ + REGISTER_KERNEL_BUILDER( \ + Name("RefReturn").Device(DEVICE_GPU).TypeConstraint("T"), ReturnOp); + +TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_KERNEL); +TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_REF_KERNEL); +REGISTER_GPU_KERNEL(bool); +REGISTER_GPU_REF_KERNEL(bool); + +#undef REGISTER_GPU_KERNEL +#undef REGISTER_GPU_REF_KERNEL + +#ifdef TENSORFLOW_USE_SYCL + #define REGISTER_SYCL_KERNEL(type) \ + REGISTER_KERNEL_BUILDER( \ + Name("Return").Device(DEVICE_SYCL).TypeConstraint("T"), ReturnOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("RefReturn").Device(DEVICE_SYCL).TypeConstraint("T"), ReturnOp); +REGISTER_SYCL_KERNEL(bool); +TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_SYCL_KERNEL); + +#undef REGISTER_SYCL_KERNEL +#undef REGISTER_SYCL_REF_KERNEL + +#define REGISTER_SYCL_HOST_KERNEL(type) \ + REGISTER_KERNEL_BUILDER(Name("Return") \ + .Device(DEVICE_SYCL) \ + .HostMemory("data") \ + .HostMemory("output") \ + .TypeConstraint("T"), \ + ReturnOp); \ + REGISTER_KERNEL_BUILDER(Name("RefReturn") \ + .Device(DEVICE_SYCL) \ + .HostMemory("data") \ + .HostMemory("output") \ + .TypeConstraint("T"), \ + ReturnOp) + +REGISTER_SYCL_HOST_KERNEL(int32); +REGISTER_SYCL_HOST_KERNEL(string); +#undef REGISTER_SYCL_HOST_KERNEL +#endif // TENSORFLOW_USE_SYCL + +#define REGISTER_GPU_HOST_KERNEL(type) \ + REGISTER_KERNEL_BUILDER(Name("Return") \ + .Device(DEVICE_GPU) \ + .HostMemory("data") \ + .HostMemory("output") \ + .TypeConstraint("T"), \ + ReturnOp); \ + REGISTER_KERNEL_BUILDER(Name("RefReturn") \ + .Device(DEVICE_GPU) \ + .HostMemory("data") \ + .HostMemory("output") \ + .TypeConstraint("T"), \ + ReturnOp) + +REGISTER_GPU_HOST_KERNEL(int32); +REGISTER_GPU_HOST_KERNEL(string); + +#undef REGISTER_GPU_HOST_KERNEL + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/function_control_ops.h b/tensorflow/core/kernels/function_control_ops.h new file mode 100644 index 00000000000000..b03d3eae9a39c5 --- /dev/null +++ b/tensorflow/core/kernels/function_control_ops.h @@ -0,0 +1,47 @@ +/* Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_KERNELS_FUNCTION_CONTROL_OPS_H_ +#define TENSORFLOW_KERNELS_FUNCTION_CONTROL_OPS_H_ + +#include "tensorflow/core/framework/op_kernel.h" + +namespace tensorflow { + +// A call op has one input and one output. It creates or finds +// the child frame that is uniquely identified by the frame_name, +// and makes its input available to the child frame. +class CallOpe : public OpKernel { +public: + explicit CallOpe(OpKernelConstruction* context) : OpKernel(context) {} + void Compute(OpKernelContext* context) override; + bool IsExpensive() override { return false; } + ~CallOpe() override {} + + TF_DISALLOW_COPY_AND_ASSIGN(CallOpe); +}; + +// A Return op has one input and one output. 
It exits the current +// frame to its parent frame, and makes its input available to the +// parent frame only if it receives a tensor with a specific tag. +class ReturnOp : public OpKernel { +public: + explicit ReturnOp(OpKernelConstruction* context) : OpKernel(context) {} + void Compute(OpKernelContext* context) override; + bool IsExpensive() override { return false; } + ~ReturnOp() override {} + + TF_DISALLOW_COPY_AND_ASSIGN(ReturnOp); +}; +} // namespace tensorflow + +#endif diff --git a/tensorflow/core/ops/BUILD b/tensorflow/core/ops/BUILD index c1e0497969e7fc..44fbb61eea480a 100644 --- a/tensorflow/core/ops/BUILD +++ b/tensorflow/core/ops/BUILD @@ -58,6 +58,7 @@ tf_gen_op_libs( "filesystem_ops", "function_ops", "functional_ops", + "function_control_ops", "image_ops", "io_ops", "linalg_ops", @@ -293,6 +294,7 @@ cc_library( ":experimental_dataset_ops_op_lib", ":filesystem_ops_op_lib", ":function_ops_op_lib", + ":function_control_ops_op_lib", ":functional_ops_op_lib", ":image_ops_op_lib", ":io_ops_op_lib", diff --git a/tensorflow/core/ops/function_control_ops.cc b/tensorflow/core/ops/function_control_ops.cc new file mode 100644 index 00000000000000..c337160b0fa5da --- /dev/null +++ b/tensorflow/core/ops/function_control_ops.cc @@ -0,0 +1,116 @@ +/* Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/core/framework/common_shape_fns.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/shape_inference.h" + +namespace tensorflow { + +using shape_inference::InferenceContext; +using shape_inference::ShapeHandle; + +// -------------------------------------------------------------------------- +REGISTER_OP("Call") + .Input("data: T") + .Output("output: T") + .Attr("T: type") + .Attr("frame_name: string") + .Attr("call_id: int") + .Attr("arg_id: int") + .Attr("is_constant: bool = false") + .SetShapeFn([](InferenceContext* c) { + c->set_output(0, c->UnknownShape()); + + // Handle resource shape / dtype, if present. + auto* handle_data = c->input_handle_shapes_and_types(0); + if (handle_data != nullptr) { + c->set_output_handle_shapes_and_types(0, *handle_data); + } else { + // Otherwise, propagate shape if output is a constant. + bool is_constant; + TF_RETURN_IF_ERROR(c->GetAttr("is_constant", &is_constant)); + if (is_constant) { + c->set_output(0, c->input(0)); + } + } + return OkStatus(); + }) + .Doc(R"Doc( +Creates (or finds) a child frame, and makes `data` available to the child frame. + +This op is used together with `Return` to create recursive calls in the graph. +The unique `frame_name` is used by the `Executor` to identify frames. + +data: The tensor to be made available to the child frame. +frame_name: The name of the child frame. +output: The same tensor as `data`. + +Returns tensors with the same shapes and contents as the input +tensors. 
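+call_id: An id shared by the `Call` and `Return` nodes generated for the same
+  call site of the function.
+arg_id: The index of the function argument that this `Call` node forwards to
+  the child frame.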
+ )Doc"); + +REGISTER_OP("RefCall") + .Input("data: Ref(T)") + .Output("output: Ref(T)") + .Attr("T: type") + .Attr("frame_name: string") + .Attr("call_id: int") + .Attr("arg_id: int") + .Attr("is_constant: bool = false") + .SetShapeFn(shape_inference::UnchangedShape) + .Doc(R"Doc( +Creates (or finds) a child frame, and makes `data` available to the child frame. + +This op is used together with `Return` to create recursive calls in the graph. +The unique `frame_name` is used by the `Executor` to identify frames. + +data: The tensor to be made available to the child frame. +frame_name: The name of the child frame. +output: The same tensor as `data`. + +Returns tensors with the same shapes and contents as the input +tensors. + )Doc"); + +// -------------------------------------------------------------------------- +REGISTER_OP("Return") +.Input("data: T") +.Output("output: T") +.Attr("T: type") +.Attr("frame_name: string") +.Attr("call_id: int") +.Attr("arg_id: int") +.SetShapeFn(shape_inference::UnchangedShape) +.Doc(R"Doc( +Exits the current frame to its parent frame. +Exit makes its input `data` available to the parent frame. +data: The list of tensors to be made available to the parent frame. +output: The same list of tensors as `data`. + )Doc"); + +REGISTER_OP("RefReturn") +.Input("data: Ref(T)") +.Output("output: Ref(T)") +.Attr("T: type") +.Attr("frame_name: string") +.Attr("call_id: int") +.Attr("arg_id: int") +.SetShapeFn(shape_inference::UnchangedShape) +.Doc(R"Doc( +Exits the current frame to its parent frame. +Exit makes its input `data` available to the parent frame. +data: The tensors to be made available to the parent frame. +output: The same tensors as `data`. + )Doc"); + +} // namespace tensorflow diff --git a/tensorflow/core/protobuf/rewriter_config.proto b/tensorflow/core/protobuf/rewriter_config.proto index f98d1928d9e156..aedf6a8bb0fa69 100644 --- a/tensorflow/core/protobuf/rewriter_config.proto +++ b/tensorflow/core/protobuf/rewriter_config.proto @@ -92,6 +92,8 @@ message RewriterConfig { Toggle function_optimization = 10; // Strips debug-related nodes from the graph (off by default). Toggle debug_stripper = 11; + // Function transformation (default is ON). 
+ Toggle function_transformation = 33; // If true, don't remove unnecessary ops from the graph bool disable_model_pruning = 2; // Try to allocate some independent Op outputs contiguously in order to From 3458bd98af32fee7f02548ee2bcc04fc0ce711f5 Mon Sep 17 00:00:00 2001 From: George Vasilakopoulos Date: Wed, 28 Feb 2024 00:46:54 +0000 Subject: [PATCH 02/53] Debugging --- .../core/common_runtime/graph_constructor.cc | 4 +-- tensorflow/core/grappler/optimizers/BUILD | 1 + .../optimizers/function_transformation.cc | 26 +++++++++---------- .../function_transformation.h | 0 .../grappler/optimizers/meta_optimizer.cc | 2 +- .../core/grappler/utils/topological_sort.cc | 2 +- 6 files changed, 18 insertions(+), 17 deletions(-) rename tensorflow/core/grappler/{ => optimizers}/function_transformation.h (100%) diff --git a/tensorflow/core/common_runtime/graph_constructor.cc b/tensorflow/core/common_runtime/graph_constructor.cc index b45b76f0ffc776..951c8a8c9e0f04 100644 --- a/tensorflow/core/common_runtime/graph_constructor.cc +++ b/tensorflow/core/common_runtime/graph_constructor.cc @@ -665,7 +665,7 @@ Status GraphConstructor::EnsureNoNameCollisions() { Status GraphConstructor::PopulateFunctionReturningNodes() { std::unordered_map> returning_nodes; for (int n = 0; n < node_def_count(); ++n) { - const NodeDef& node_def = get_node_def(); + const NodeDef& node_def = get_node_def(n); if (IsReturn(node_def)){ // Nodes that send their output to "Return" nodes are // function Returning Nodes and in case of recursive functions @@ -783,7 +783,7 @@ Status GraphConstructor::InitFromEdges() { } } - gt1::FlatSet call_nodes; + gtl::FlatSet call_nodes; for (int n = 0; n < node_def_count(); ++n) { const NodeDef& node_def = get_node_def(n); if (IsCall(node_def)) { diff --git a/tensorflow/core/grappler/optimizers/BUILD b/tensorflow/core/grappler/optimizers/BUILD index 1881b2496b9913..55ad6ade3326df 100644 --- a/tensorflow/core/grappler/optimizers/BUILD +++ b/tensorflow/core/grappler/optimizers/BUILD @@ -683,6 +683,7 @@ cc_library( ":dependency_optimizer", ":function_optimizer", ":generic_layout_optimizer", + ":function_transformation", ":graph_optimizer", ":implementation_selector", ":loop_optimizer", diff --git a/tensorflow/core/grappler/optimizers/function_transformation.cc b/tensorflow/core/grappler/optimizers/function_transformation.cc index 7963083e7a384f..6623b7fceaa12f 100644 --- a/tensorflow/core/grappler/optimizers/function_transformation.cc +++ b/tensorflow/core/grappler/optimizers/function_transformation.cc @@ -164,7 +164,7 @@ Status CopyArgType(const OpDef::ArgDef& arg, } *type = it->type(); } - return Status::OK(); + return OkStatus(); } // Copy input/output argument type to the type_list. 
Return error if argument @@ -181,7 +181,7 @@ Status CopyArgType(const OpDef::ArgDef& arg, } type_list->add_type(it->type()); } - return Status::OK(); + return OkStatus(); } struct CallInfo { @@ -314,7 +314,7 @@ Status CallRewriter::CollectCalls(std::vector& calls) { calls.push_back(call); } } - return Status::OK(); + return OkStatus(); } Status CallRewriter::AddCallOp(const CallInfo& call_info, @@ -341,7 +341,7 @@ Status CallRewriter::AddCallOp(const CallInfo& call_info, attr["arg_id"].set_i(arg_id); attr["is_constant"].set_b(false); - return Status::OK(); + return OkStatus(); } Status CallRewriter::AddRetOp(const CallInfo& call_info, @@ -363,7 +363,7 @@ Status CallRewriter::AddRetOp(const CallInfo& call_info, attr["call_id"].set_i(call_info.call_id); attr["arg_id"].set_i(arg_id); - return Status::OK(); + return OkStatus(); } Status CallRewriter::ConnectInput(NodeDef* from, NodeDef* to) { @@ -377,7 +377,7 @@ Status CallRewriter::ConnectInput(NodeDef* from, NodeDef* to) { if (to->input_size() > 1) { (*to->mutable_attr())["N"].set_i(to->input_size()); } - return Status::OK(); + return OkStatus(); } Status CallRewriter::TransformCall(CallInfo& call_info) { @@ -454,7 +454,7 @@ Status CallRewriter::TransformCall(CallInfo& call_info) { printf("Mark call %s (function %s) as transformed\n", call_info.node_name.c_str(), call_info.function_name.c_str()); MarkCallTransformed(call_info); - return Status::OK(); + return OkStatus(); } Status InlineFunction(const FunctionDef& func_def, @@ -549,7 +549,7 @@ Status InlineFunction(const FunctionDef& func_def, func_info.output_def[i] = func_def.signature().output_arg(i); } - return Status::OK(); + return OkStatus(); } // new @@ -564,7 +564,7 @@ Status CallRewriter::FindCompatibleOrInlineFunction( // possible type specialization? 
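  // Note: inlined bodies are cached in transformed_functions_, keyed by the
  // function name, so each FunctionDef is inlined at most once and the cached
  // FuncInfo is reused for every later call site of the same function.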
if (it != transformed_functions_.end()) { func_info = it->second; - return Status::OK(); + return OkStatus(); } const FunctionDef* func_def = ctx.FindInlinedFunction(func_name); if (func_def == nullptr) { @@ -576,7 +576,7 @@ Status CallRewriter::FindCompatibleOrInlineFunction( InlineFunction(*func_def, ctx, func_attr, device, graph, func_info)); transformed_functions_[func_name] = func_info; printf("Store inlined function %s\n", func_name.c_str()); - return Status::OK(); + return OkStatus(); } } // namespace @@ -588,7 +588,7 @@ Status FunctionTransformation::Optimize(Cluster* cluster, const GrapplerItem& it *output = item.graph; if (!ctx.HasInlinedFunctions()) { - return Status::OK(); + return OkStatus(); } std::vector calls; @@ -600,7 +600,7 @@ Status FunctionTransformation::Optimize(Cluster* cluster, const GrapplerItem& it for (CallInfo& call : calls) { Status s = call_rewriter.TransformCall(call); if (!s.ok()) { - printf("Error: %s\n", s.error_message().c_str()); + printf("Error: %s\n", tsl::NullTerminatedMessage(s)); return s; } printf("After transforming call %s:\n %s\n", call.function_name.c_str(), SummarizeGraphDef(*output).c_str()); @@ -643,7 +643,7 @@ Status FunctionTransformation::Optimize(Cluster* cluster, const GrapplerItem& it writer.WriteEvent(event); /******************************************************************************************************/ - return Status::OK(); + return OkStatus(); } } // end namespace grappler diff --git a/tensorflow/core/grappler/function_transformation.h b/tensorflow/core/grappler/optimizers/function_transformation.h similarity index 100% rename from tensorflow/core/grappler/function_transformation.h rename to tensorflow/core/grappler/optimizers/function_transformation.h diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer.cc b/tensorflow/core/grappler/optimizers/meta_optimizer.cc index 1da252ce7a51b2..20e19fc473957a 100644 --- a/tensorflow/core/grappler/optimizers/meta_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/meta_optimizer.cc @@ -496,7 +496,7 @@ Status MetaOptimizer::InitializeOptimizers( USER_IS_EXPERIMENTAL_BOTH(function_transformation)) { VLOG(2) << "function_transformation is not implemented in TFG yet"; } else { - optimizers->push_back(MakeUnique()); + optimizers->push_back(std::make_unique()); } } #ifndef ENABLE_MKL diff --git a/tensorflow/core/grappler/utils/topological_sort.cc b/tensorflow/core/grappler/utils/topological_sort.cc index 5f4be7975449ce..e6b70c3c2c4fe5 100644 --- a/tensorflow/core/grappler/utils/topological_sort.cc +++ b/tensorflow/core/grappler/utils/topological_sort.cc @@ -98,7 +98,7 @@ Status ComputeTopologicalOrder( // if (!StringPiece(graph.node(input)).starts_with("^")) { if (true) { int call_id; - GetNodeAttr(graph.node(i), "call_id", &call_id); + TF_CHECK_OK(GetNodeAttr(graph.node(i), "call_id", &call_id)); returning_nodes[input].emplace(call_id); } } From cd23277b30b6ffaa6295379f9f92f64f826f01ba Mon Sep 17 00:00:00 2001 From: George Vasilakopoulos Date: Wed, 28 Feb 2024 07:34:09 +0000 Subject: [PATCH 03/53] Debugging --- tensorflow/core/common_runtime/graph_constructor.cc | 2 +- .../core/grappler/optimizers/function_transformation.cc | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/tensorflow/core/common_runtime/graph_constructor.cc b/tensorflow/core/common_runtime/graph_constructor.cc index 951c8a8c9e0f04..bc3c98787bd685 100644 --- a/tensorflow/core/common_runtime/graph_constructor.cc +++ b/tensorflow/core/common_runtime/graph_constructor.cc @@ -685,7 +685,7 
@@ Status GraphConstructor::PopulateFunctionReturningNodes() { int call_id; - GetNodeAttr(AttrSlice(node_def), "call_id", &call_id); + TF_CHECK_OK(GetNodeAttr(AttrSlice(node_def), "call_id", &call_id)); returning_nodes[prevNode].emplace(call_id); } } diff --git a/tensorflow/core/grappler/optimizers/function_transformation.cc b/tensorflow/core/grappler/optimizers/function_transformation.cc index 6623b7fceaa12f..b4b2e35c67e245 100644 --- a/tensorflow/core/grappler/optimizers/function_transformation.cc +++ b/tensorflow/core/grappler/optimizers/function_transformation.cc @@ -395,11 +395,11 @@ Status CallRewriter::TransformCall(CallInfo& call_info) { call_nodes.resize(func_info.inputs.size()); for (unsigned int arg_num = 0; arg_num < func_info.inputs.size(); arg_num++) { call_nodes[arg_num] = graph->add_node(); - AddCallOp(call_info, + TF_CHECK_OK(AddCallOp(call_info, func_info.input_def[arg_num], call_info.input_nodes[arg_num], arg_num, - call_nodes[arg_num]); + call_nodes[arg_num])); call_nodes[arg_num]->set_device(call_info.device); @@ -410,11 +410,11 @@ Status CallRewriter::TransformCall(CallInfo& call_info) { ret_nodes.resize(func_info.outputs.size()); for (unsigned int out_port = 0; out_port < func_info.outputs.size(); out_port++) { ret_nodes[out_port] = graph->add_node(); - AddRetOp(call_info, + TF_CHECK_OK(AddRetOp(call_info, func_info.output_def[out_port], func_info.outputs[out_port], out_port, - ret_nodes[out_port]); + ret_nodes[out_port])); ret_nodes[out_port]->set_device(call_info.device); } From acb668eca3977fa518d5284eaf40bb3db8ab7477 Mon Sep 17 00:00:00 2001 From: George Vasilakopoulos Date: Tue, 5 Mar 2024 20:53:59 +0200 Subject: [PATCH 04/53] Added TESTS --- TESTS/factorial.py | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) create mode 100644 TESTS/factorial.py diff --git a/TESTS/factorial.py b/TESTS/factorial.py new file mode 100644 index 00000000000000..25542860f9b6ea --- /dev/null +++ b/TESTS/factorial.py @@ -0,0 +1,29 @@ +import tensorflow as tf +from tensorflow.python.framework import function + +fac = function.Declare("Fac", [("n", tf.int32)], [("ret", tf.int32)]) + +@function.Defun(tf.int32, func_name="Fac", out_names=["ret"]) +def FacImpl(n): + return tf.cond(tf.less_equal(n, 1), + lambda: tf.constant(1), + lambda: n * fac(n - 1)) + + +FacImpl.add_to_graph(tf.get_default_graph()) + +n = tf.placeholder(tf.int32, shape=[]) +x = tf.add(n, 1) +result = fac(x) +y = tf.add(result, 1) + +#print(tf.get_default_graph().as_graph_def()) + +writer = tf.summary.FileWriter('./graphs', tf.get_default_graph()) + +sess = tf.Session() +print(sess.run(y, feed_dict={n: 5})) + +writer.close() + +sess.close() From 49cc2e5ae8fd356515c5c0b96b61ec67a7b1a14d Mon Sep 17 00:00:00 2001 From: George Vasilakopoulos Date: Tue, 12 Mar 2024 20:06:48 +0000 Subject: [PATCH 05/53] Container backup 1203 --- TESTS/comp | 7 ++ TESTS/factorial.py | 67 ++++++++--- TESTS/nohup.out | 0 TESTS/test.py | 35 ++++++ .../revived_types/flat_tensor_function.cc | 2 + tensorflow/core/framework/op.cc | 6 + .../optimizers/function_transformation.cc | 18 +++ .../grappler/optimizers/meta_optimizer.cc | 9 +- .../optimizers/my_bad_transformation.cc | 113 ++++++++++++++++++ tensorflow/python/autograph/impl/api.py | 6 + tensorflow/python/framework/function.py | 56 +++++++++ 11 files changed, 298 insertions(+), 21 deletions(-) create mode 100755 TESTS/comp create mode 100644 TESTS/nohup.out create mode 100644 TESTS/test.py create mode 100644 tensorflow/core/grappler/optimizers/my_bad_transformation.cc diff --git 
a/TESTS/comp b/TESTS/comp new file mode 100755 index 00000000000000..652ee962c7f8bb --- /dev/null +++ b/TESTS/comp @@ -0,0 +1,7 @@ +#!/bin/bash + +pip uninstall tensorflow -y +cd .. +bazel build //tensorflow/tools/pip_package:build_pip_package +./bazel-bin/tensorflow/tools/pip_package/build_pip_package /mnt +pip install /mnt/tensorflow-*.whl diff --git a/TESTS/factorial.py b/TESTS/factorial.py index 25542860f9b6ea..16c6d8f5c06ded 100644 --- a/TESTS/factorial.py +++ b/TESTS/factorial.py @@ -1,29 +1,60 @@ import tensorflow as tf -from tensorflow.python.framework import function -fac = function.Declare("Fac", [("n", tf.int32)], [("ret", tf.int32)]) -@function.Defun(tf.int32, func_name="Fac", out_names=["ret"]) -def FacImpl(n): - return tf.cond(tf.less_equal(n, 1), - lambda: tf.constant(1), - lambda: n * fac(n - 1)) +# Verbosity is now 5 +log_dir = "./graph" -FacImpl.add_to_graph(tf.get_default_graph()) +tf.config.run_functions_eagerly(False) +writer = tf.summary.create_file_writer(log_dir) +# tf.autograph.set_verbosity(5,True) -n = tf.placeholder(tf.int32, shape=[]) -x = tf.add(n, 1) -result = fac(x) -y = tf.add(result, 1) +# @tf.function +# def fact(x): +# if(x == 1): +# return 1 +# else: +# return x * fact(x-1) -#print(tf.get_default_graph().as_graph_def()) -writer = tf.summary.FileWriter('./graphs', tf.get_default_graph()) +def fact_loop(n): + i = tf.constant(1) + result = tf.constant(1) -sess = tf.Session() -print(sess.run(y, feed_dict={n: 5})) + # Loop condition: continue while i <= n + def cond(i, result): + return i <= n -writer.close() + # Loop body: multiply result by i and increment i + def body(i, result): + return i + 1, result * i -sess.close() + # Execute the while loop + _, final_result = tf.while_loop(cond, body, loop_vars=[i, result]) + return final_result + +# def fact(n): +# condition = tf.equal(n, tf.constant(1)) +# true_branch = tf.constant(1) +# subtraction_op = tf.subtract(n, 1) +# false_branch = tf.multiply(fact(subtraction_op), n) +# return tf.cond(condition, lambda: true_branch, lambda: false_branch) + +transformed_f = tf.function(fact_loop) + +result = transformed_f(3) +print(result) +# print(tf.autograph.to_code(transformed_f)) +# graph_f = tf.autograph.to_graph(fact) + + +# print(graph_f) + + + + +tf.summary.trace_on(graph=True, profiler=False) +# result = transformed_f(tf.constant(3)) +# print(result) +with writer.as_default(): + tf.summary.trace_export("my_function_trace", step=0, profiler_outdir=log_dir) diff --git a/TESTS/nohup.out b/TESTS/nohup.out new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/TESTS/test.py b/TESTS/test.py new file mode 100644 index 00000000000000..55bcdc66b211c6 --- /dev/null +++ b/TESTS/test.py @@ -0,0 +1,35 @@ +import tensorflow as tf +from tensorflow.python.framework import function + +log_dir = "./graph" + +tf.config.run_functions_eagerly(False) + +fac = function.Declare("Fac", [("n", tf.int32)], [("ret", tf.int32)]) +writer = tf.summary.create_file_writer(log_dir) + + +def FacImpl(n): + condition = tf.equal(n, tf.constant(1)) + true_branch = tf.constant(1) + subtraction_op = tf.subtract(n, 1) + false_branch = tf.multiply(fac(subtraction_op), n) + return tf.cond(condition, lambda: true_branch, lambda: false_branch) + +transformed_function = tf.function(FacImpl) + +# print(tf.get_default_graph().as_graph_def()) +# print(FacImpl(3)) + + +# print(tf.autograph.to_code(FacImpl)) +print(transformed_function.get_concrete_function(tf.constant(1)).graph.as_graph_def()) + + +# tf.summary.trace_on(graph=True, 
profiler=False) +# # result = transformed_f(tf.constant(3)) +# # print(result) +# with writer.as_default(): +# tf.summary.trace_export("my_function_trace", step=0, profiler_outdir=log_dir) + + diff --git a/tensorflow/c/experimental/saved_model/core/revived_types/flat_tensor_function.cc b/tensorflow/c/experimental/saved_model/core/revived_types/flat_tensor_function.cc index d6e568090f7d27..ec8d4ad31d0f44 100644 --- a/tensorflow/c/experimental/saved_model/core/revived_types/flat_tensor_function.cc +++ b/tensorflow/c/experimental/saved_model/core/revived_types/flat_tensor_function.cc @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include + #include "tensorflow/c/experimental/saved_model/core/revived_types/flat_tensor_function.h" #include diff --git a/tensorflow/core/framework/op.cc b/tensorflow/core/framework/op.cc index ccd5edcb3d37b5..4837dbe5fc3b2d 100644 --- a/tensorflow/core/framework/op.cc +++ b/tensorflow/core/framework/op.cc @@ -15,6 +15,8 @@ limitations under the License. #include "tensorflow/core/framework/op.h" +#include + #include #include #include @@ -79,6 +81,10 @@ Status OpNotFound(const string& op_type_name) { Status OpRegistry::LookUp(const string& op_type_name, const OpRegistrationData** op_reg_data) const { if ((*op_reg_data = LookUp(op_type_name))) return OkStatus(); + std::ofstream outputFile("/tensorflow/TESTS/mylog.txt", std::ios::app); + outputFile << "Searching for " << op_type_name << std::endl; + + outputFile.close(); return OpNotFound(op_type_name); } diff --git a/tensorflow/core/grappler/optimizers/function_transformation.cc b/tensorflow/core/grappler/optimizers/function_transformation.cc index b4b2e35c67e245..465d0a844a9730 100644 --- a/tensorflow/core/grappler/optimizers/function_transformation.cc +++ b/tensorflow/core/grappler/optimizers/function_transformation.cc @@ -11,6 +11,9 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ +#include +// std::ofstream outputFile("/tensorflow/TESTS/mylog.txt"); + #include "tensorflow/core/grappler/optimizers/function_transformation.h" #include #include @@ -77,7 +80,11 @@ class FunctionInliningContext { private: std::unordered_map InliningCandidates(const GrapplerItem& item) const { std::unordered_map functions; + + + // outputFile << "In inliningcandidates " << SummarizeGraphDef(item.graph)<< std::endl; for (const FunctionDef& func : item.graph.library().function()) { + // outputFile << func.signature().name() << std::endl; // Don't inline functions marked as noinline // if (func.attr().count("_noinline") != 0) { // continue; @@ -96,6 +103,7 @@ class FunctionInliningContext { } functions[func.signature().name()] = &func; } + // outputFile << "Returning!"<& calls) { // identify and collect calls in the graph + // outputFile << "In collect calls: "<< std::endl; for (NodeDef& node : *graph->mutable_node()) { const FunctionDef* func = ctx.FindInlinedFunction(node.op()); + // outputFile << "Collecting Calls: "<< node.name() << " " << node.op() << std::endl; if (func != nullptr) { CallInfo call; call.call_id = GetCallId(node); @@ -583,11 +593,18 @@ Status CallRewriter::FindCompatibleOrInlineFunction( Status FunctionTransformation::Optimize(Cluster* cluster, const GrapplerItem& item, GraphDef* output) { + + + + // outputFile << "In Optimize" << std::endl; + + FunctionInliningContext ctx(item); CallRewriter call_rewriter(item, output, ctx); *output = item.graph; if (!ctx.HasInlinedFunctions()) { + // outputFile << "No inlining functions!"<mutable_versions() = item.graph.versions(); diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer.cc b/tensorflow/core/grappler/optimizers/meta_optimizer.cc index 20e19fc473957a..5b936537a7510e 100644 --- a/tensorflow/core/grappler/optimizers/meta_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/meta_optimizer.cc @@ -212,6 +212,12 @@ std::unique_ptr MetaOptimizer::MakeNewOptimizer( cfg_.use_plugin_optimizers() != RewriterConfig::OFF, device_types); if (optimizer == "pruning" && !plugin_configs.disable_model_pruning) return std::unique_ptr(new ModelPruner()); + + + // if (LowerControlFlow()) { + MK_OPT("function_transformation", "function_transformation", new FunctionTransformation()); + // } + MK_OPT("function", "function_optimization", new FunctionOptimizer(cfg_.function_optimization(), /*lower_control_flow=*/LowerControlFlow())); @@ -262,9 +268,6 @@ std::unique_ptr MetaOptimizer::MakeNewOptimizer( MK_OPT("pin_to_host", "pin_to_host_optimization", new PinToHostOptimizer(cfg_.pin_to_host_optimization())); - if (LowerControlFlow()) { - MK_OPT("function_transformation", "function_transformation", new FunctionTransformation()); - } return std::unique_ptr(); } diff --git a/tensorflow/core/grappler/optimizers/my_bad_transformation.cc b/tensorflow/core/grappler/optimizers/my_bad_transformation.cc new file mode 100644 index 00000000000000..e315e3b12b08ad --- /dev/null +++ b/tensorflow/core/grappler/optimizers/my_bad_transformation.cc @@ -0,0 +1,113 @@ +/* Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/grappler/optimizers/function_transformation.h" +#include +#include +#include +#include "tensorflow/core/util/event.pb.h" +#include "tensorflow/core/util/events_writer.h" + +#include "tensorflow/core/common_runtime/graph_constructor.h" +#include "tensorflow/core/graph/tensor_id.h" +#include "tensorflow/core/grappler/grappler_item.h" +#include "tensorflow/core/grappler/op_types.h" +#include "tensorflow/core/grappler/utils.h" +#include "tensorflow/core/grappler/utils/functions.h" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/function.pb.h" +#include "tensorflow/core/framework/graph_def_util.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/framework/op_def.pb.h" +#include "tensorflow/core/framework/versions.pb.h" + +namespace tensorflow { +namespace grappler { +namespace { + + +class CallRewriter { + + +}; + + + +Status FunctionTransformation::Optimize(Cluster* cluster, const GrapplerItem& item, + GraphDef* output) { + FunctionInliningContext ctx(item); + CallRewriter call_rewriter(item, output, ctx); + + *output = item.graph; + if (!ctx.HasInlinedFunctions()) { + return OkStatus(); + } + + std::vector calls; + while (1) { + TF_RETURN_IF_ERROR(call_rewriter.CollectCalls(calls)); + if (calls.empty()) { + break; + } + for (CallInfo& call : calls) { + Status s = call_rewriter.TransformCall(call); + if (!s.ok()) { + printf("Error: %s\n", tsl::NullTerminatedMessage(s)); + return s; + } + printf("After transforming call %s:\n %s\n", call.function_name.c_str(), SummarizeGraphDef(*output).c_str()); + } + calls.clear(); + call_rewriter.Flush(); + } + call_rewriter.Flush(); + printf("After finalizing:\n %s\n", SummarizeGraphDef(*output).c_str()); + *output->mutable_versions() = item.graph.versions(); + + // Function Library should be pruned of unreachable function definitions + // cf. https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/grappler/optimizers/function_optimizer.cc#L428 + // however in this version there is a check in meta_optimizer that guarantees + // that function library remains of the same length + // cf. 
https://github.com/acharal/tensorflow/blob/r1.4_recursion/tensorflow/core/grappler/optimizers/meta_optimizer.cc#L132 + *output->mutable_library() = item.graph.library(); + + + + /******************************************************************************************************/ + // Dumps optimized graph in a not so readable form + // const GraphDef* tmp = optimized_graph; + // printf("Summarize Optimized Graph\n %s\n", SummarizeGraphDef(*tmp).c_str()); + // Write an event, so that we can visualize this optimized graph in tensorboard + EventsWriter writer("TRANSFORMATION"); + Event event; + event.set_wall_time(1234); + event.set_step(34); + const size_t proto_size = output->ByteSizeLong(); + void* buf = port::Malloc(proto_size); + if (buf == nullptr) { + return errors::ResourceExhausted( + "Failed to allocate memory to serialize message of type '" , + output->GetTypeName(), "' and size ", proto_size); + } + output->SerializeToArray(buf, proto_size); + const void* bf = buf; + event.set_graph_def(bf, proto_size); + writer.WriteEvent(event); + /******************************************************************************************************/ + + return OkStatus(); +} + +} // end namespace grappler +} // end namespace tensorflow diff --git a/tensorflow/python/autograph/impl/api.py b/tensorflow/python/autograph/impl/api.py index b60c457599e717..c89e995bd461b8 100644 --- a/tensorflow/python/autograph/impl/api.py +++ b/tensorflow/python/autograph/impl/api.py @@ -320,6 +320,8 @@ def converted_call(f, args, kwargs, caller_fn_scope=None, options=None): """ logging.log(1, 'Converted call: %s\n args: %s\n kwargs: %s\n', f, args, kwargs) + + if options is None: if caller_fn_scope is None: @@ -338,6 +340,10 @@ def converted_call(f, args, kwargs, caller_fn_scope=None, options=None): logging.log(2, 'Permanently allowed: %s: AutoGraph artifact', f) return _call_unconverted(f, args, kwargs, options) + + + + # If this is a partial, unwrap it and redo all the checks. if isinstance(f, functools.partial): new_kwargs = {} diff --git a/tensorflow/python/framework/function.py b/tensorflow/python/framework/function.py index 848a4c8f23599f..8fb6efa0bab7bd 100644 --- a/tensorflow/python/framework/function.py +++ b/tensorflow/python/framework/function.py @@ -21,6 +21,7 @@ import collections import hashlib +from tensorflow.core.framework import op_def_pb2 from tensorflow.core.framework import attr_value_pb2 from tensorflow.core.framework import function_pb2 from tensorflow.python.client import pywrap_tf_session as c_api @@ -245,6 +246,61 @@ def __del__(self): # been unloaded. Will catch other module unloads as well. +class Declare(object): + """Declares a TensorFlow function. + + The object represents a TensorFlow function which will be defined + later during a graph construction. + + For example, + # Declares a function Foo, which takes a tf.int32 named "n" and a + # tf.float32 named "x" as inputs and returns a tf.float32 named "z" + # as its output. + foo = Declare("Foo", [("n", tf.int32), ("x", tf.float32)], + [("z", tf.float32)]) + + # Defines a function Bar calls Foo. + @tf.Defun(tf.float32) + def Bar(x): + return foo(6, x) + + # Defines Foo, with output named "z". + @tf.Defun(tf.int32, tf.float32, out_names=["z"]) + def Foo(n, x): + ... # Calculation. + return result + """ + + + def __init__(self, func_name, inputs, outputs): + """Creates a `Declare` object. + + Args: + func_name: The name of the function. + inputs: A list of (name, data type) pairs of function arguments. 
+ outputs: A list of (name, data type) pairs of function return values. + """ + self._sig = op_def_pb2.OpDef() + self._sig.name = func_name + + def _to_argdef_list(args): + names = [n for n, t in args] + if len(names) != len(set(names)): + raise ValueError("Expected names to all be unique: %s" % str(names)) + return [ + op_def_pb2.OpDef.ArgDef(type=t.as_datatype_enum, name=n) + for n, t in args + ] + + self._sig.input_arg.extend(_to_argdef_list(inputs)) + self._sig.output_arg.extend(_to_argdef_list(outputs)) + + def __call__(self, *inputs, **kwargs): + inputs = [ops.convert_to_tensor(_) for _ in inputs] + return _call(self._sig, *inputs, **kwargs)[0] + + + class _DefinedFunction(object): """_DefinedFunction encapsulates a function definition and its properties. From cb3b605bd4518e81c86c60b4d58ff0d10f268e13 Mon Sep 17 00:00:00 2001 From: George Vasilakopoulos Date: Fri, 29 Mar 2024 17:06:11 +0000 Subject: [PATCH 06/53] Logging & TESTS --- TESTS/comp | 6 +-- TESTS/factorial.py | 61 ++++----------------------- tensorflow/core/framework/function.cc | 9 +++- tensorflow/core/framework/op.cc | 4 -- 4 files changed, 19 insertions(+), 61 deletions(-) diff --git a/TESTS/comp b/TESTS/comp index 652ee962c7f8bb..5e3449fb10e6e9 100755 --- a/TESTS/comp +++ b/TESTS/comp @@ -1,7 +1,7 @@ #!/bin/bash -pip uninstall tensorflow -y cd .. -bazel build //tensorflow/tools/pip_package:build_pip_package -./bazel-bin/tensorflow/tools/pip_package/build_pip_package /mnt +bazel build --disk_cache=~/mycache --config=dbg //tensorflow/tools/pip_package:build_pip_package && +./bazel-bin/tensorflow/tools/pip_package/build_pip_package /mnt && +pip uninstall tensorflow -y && pip install /mnt/tensorflow-*.whl diff --git a/TESTS/factorial.py b/TESTS/factorial.py index 16c6d8f5c06ded..f5ca7d668de073 100644 --- a/TESTS/factorial.py +++ b/TESTS/factorial.py @@ -1,60 +1,15 @@ import tensorflow as tf +from tensorflow.python.framework import function +# fac = function.Declare("Fac", [("n", tf.int32)], [("ret", tf.int32)]) -# Verbosity is now 5 +@function.Defun(tf.int32, func_name="Fac", out_names=["ret"]) +def FacImpl(n): + return tf.cond(tf.less_equal(n, 1), + lambda: tf.constant(1), + lambda: n * FacImpl(n - 1)) -log_dir = "./graph" -tf.config.run_functions_eagerly(False) -writer = tf.summary.create_file_writer(log_dir) -# tf.autograph.set_verbosity(5,True) +print(FacImpl(5)) -# @tf.function -# def fact(x): -# if(x == 1): -# return 1 -# else: -# return x * fact(x-1) - -def fact_loop(n): - i = tf.constant(1) - result = tf.constant(1) - - # Loop condition: continue while i <= n - def cond(i, result): - return i <= n - - # Loop body: multiply result by i and increment i - def body(i, result): - return i + 1, result * i - - # Execute the while loop - _, final_result = tf.while_loop(cond, body, loop_vars=[i, result]) - return final_result - -# def fact(n): -# condition = tf.equal(n, tf.constant(1)) -# true_branch = tf.constant(1) -# subtraction_op = tf.subtract(n, 1) -# false_branch = tf.multiply(fact(subtraction_op), n) -# return tf.cond(condition, lambda: true_branch, lambda: false_branch) - -transformed_f = tf.function(fact_loop) - -result = transformed_f(3) -print(result) -# print(tf.autograph.to_code(transformed_f)) -# graph_f = tf.autograph.to_graph(fact) - - -# print(graph_f) - - - - -tf.summary.trace_on(graph=True, profiler=False) -# result = transformed_f(tf.constant(3)) -# print(result) -with writer.as_default(): - tf.summary.trace_export("my_function_trace", step=0, profiler_outdir=log_dir) diff --git 
a/tensorflow/core/framework/function.cc b/tensorflow/core/framework/function.cc index 0b6bacd94af0d9..f3ad99ac3322a8 100644 --- a/tensorflow/core/framework/function.cc +++ b/tensorflow/core/framework/function.cc @@ -16,7 +16,7 @@ limitations under the License. #include "tensorflow/core/framework/function.h" #include - +#include #include #include #include @@ -1716,6 +1716,13 @@ Status FunctionLibraryDefinition::LookUp( const string& op, const OpRegistrationData** op_reg_data) const { tf_shared_lock l(mu_); auto iter = records_.find(op); + std::ofstream outputFile("/tensorflow/TESTS/mylog.txt", std::ios::app); + outputFile << "Searching for " << op << std::endl; + for(auto s : ListFunctionNames()){ + outputFile << "Function Name: " << s << std::endl; + } + + outputFile.close(); if (iter != records_.end()) { *op_reg_data = &iter->second->op_registration_data(); return OkStatus(); diff --git a/tensorflow/core/framework/op.cc b/tensorflow/core/framework/op.cc index 4837dbe5fc3b2d..6cfec232612747 100644 --- a/tensorflow/core/framework/op.cc +++ b/tensorflow/core/framework/op.cc @@ -81,10 +81,6 @@ Status OpNotFound(const string& op_type_name) { Status OpRegistry::LookUp(const string& op_type_name, const OpRegistrationData** op_reg_data) const { if ((*op_reg_data = LookUp(op_type_name))) return OkStatus(); - std::ofstream outputFile("/tensorflow/TESTS/mylog.txt", std::ios::app); - outputFile << "Searching for " << op_type_name << std::endl; - - outputFile.close(); return OpNotFound(op_type_name); } From 1a59ea6d9dfa277a510df553ddc4eba4c12629d8 Mon Sep 17 00:00:00 2001 From: George Vasilakopoulos Date: Fri, 29 Mar 2024 17:07:22 +0000 Subject: [PATCH 07/53] Allow op_def_ to be null --- tensorflow/core/framework/node_def_builder.cc | 6 ++++-- tensorflow/core/graph/node_builder.cc | 10 ++++++---- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/tensorflow/core/framework/node_def_builder.cc b/tensorflow/core/framework/node_def_builder.cc index fcf73e6970bb5c..388591a7d6f18f 100644 --- a/tensorflow/core/framework/node_def_builder.cc +++ b/tensorflow/core/framework/node_def_builder.cc @@ -45,8 +45,10 @@ NodeDefBuilder::NodeDefBuilder(StringPiece name, StringPiece op_name, if (status.ok()) { Initialize(); } else { - errors_.push_back(std::string(status.message())); + // errors_.push_back(std::string(status.message())); + // inputs_specified_ = 0; inputs_specified_ = 0; + node_def_.set_op(string(op_name)); } if (debug != nullptr) MergeDebugInfo(*debug, &node_def_); } @@ -260,7 +262,7 @@ Status NodeDefBuilder::Finalize(NodeDef* node_def, bool consume) { } // Add default values for unspecified attrs. 
- AddDefaultsToNodeDef(*op_def_, node_def); + if(op_def_ != nullptr) AddDefaultsToNodeDef(*op_def_, node_def); return OkStatus(); } diff --git a/tensorflow/core/graph/node_builder.cc b/tensorflow/core/graph/node_builder.cc index dbd8fafd1ea523..ca32cad845ddc2 100644 --- a/tensorflow/core/graph/node_builder.cc +++ b/tensorflow/core/graph/node_builder.cc @@ -134,10 +134,12 @@ Status NodeBuilder::Finalize(Graph* graph, Node** created_node, bool consume) { NodeDef node_def; TF_RETURN_IF_ERROR(def_builder_.Finalize(&node_def, consume)); - TF_RETURN_IF_ERROR(ValidateNodeDef(node_def, def_builder_.op_def())); - TF_RETURN_IF_ERROR( - CheckOpDeprecation(def_builder_.op_def(), graph->versions().producer())); - + + if(&def_builder_.op_def() != nullptr){ + TF_RETURN_IF_ERROR(ValidateNodeDef(node_def, def_builder_.op_def())); + TF_RETURN_IF_ERROR( + CheckOpDeprecation(def_builder_.op_def(), graph->versions().producer())); + } TF_ASSIGN_OR_RETURN(Node * node, graph->AddNode(std::move(node_def))); node->set_assigned_device_name(assigned_device_); From afa9c1e784e3909168dd3d4aa92e0b3a3783db21 Mon Sep 17 00:00:00 2001 From: George Vasilakopoulos Date: Sat, 6 Apr 2024 18:56:35 +0000 Subject: [PATCH 08/53] functionality for Forward Function Declaration --- TESTS/test.py | 38 +++++++++---------- TESTS/test2.py | 38 +++++++++++++++++++ tensorflow/c/c_api.h | 6 +++ tensorflow/c/c_api_function.cc | 18 +++++++++ .../python/client/_pywrap_tf_session.pyi | 1 + .../python/client/tf_session_wrapper.cc | 15 ++++++++ tensorflow/python/framework/ops.py | 15 ++++++++ 7 files changed, 112 insertions(+), 19 deletions(-) create mode 100644 TESTS/test2.py diff --git a/TESTS/test.py b/TESTS/test.py index 55bcdc66b211c6..0cd496e37ce579 100644 --- a/TESTS/test.py +++ b/TESTS/test.py @@ -1,35 +1,35 @@ +# import os + +# os.environ['PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION'] = 'python' + + import tensorflow as tf from tensorflow.python.framework import function -log_dir = "./graph" -tf.config.run_functions_eagerly(False) +tf.compat.v1.disable_eager_execution() +# tf.logging.set_verbosity(tf.logging.INFO) fac = function.Declare("Fac", [("n", tf.int32)], [("ret", tf.int32)]) -writer = tf.summary.create_file_writer(log_dir) -def FacImpl(n): - condition = tf.equal(n, tf.constant(1)) - true_branch = tf.constant(1) - subtraction_op = tf.subtract(n, 1) - false_branch = tf.multiply(fac(subtraction_op), n) - return tf.cond(condition, lambda: true_branch, lambda: false_branch) +@function.Defun(tf.int32, func_name="Test", out_names=["ret"]) +def t(n): + return tf.constant(1) + -transformed_function = tf.function(FacImpl) -# print(tf.get_default_graph().as_graph_def()) -# print(FacImpl(3)) +# fac = function.Declare("Fac", [("n", tf.int32)], [("ret", tf.int32)]) +@function.Defun(tf.int32, func_name="Fac", out_names=["ret"]) +def FacImpl(n): + return t(n) -# print(tf.autograph.to_code(FacImpl)) -print(transformed_function.get_concrete_function(tf.constant(1)).graph.as_graph_def()) -# tf.summary.trace_on(graph=True, profiler=False) -# # result = transformed_f(tf.constant(3)) -# # print(result) -# with writer.as_default(): -# tf.summary.trace_export("my_function_trace", step=0, profiler_outdir=log_dir) +FacImpl.add_to_graph(tf.compat.v1.get_default_graph()) +with tf.compat.v1.Session() as sess: + result = FacImpl(1) + print("Result:", sess.run(result)) diff --git a/TESTS/test2.py b/TESTS/test2.py new file mode 100644 index 00000000000000..fff5ad068bca0a --- /dev/null +++ b/TESTS/test2.py @@ -0,0 +1,38 @@ +# import os + +# 
os.environ['PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION'] = 'python' + + +import tensorflow as tf +from tensorflow.python.framework import function + + + +tf.compat.v1.disable_eager_execution() + +# tf.logging.set_verbosity(tf.logging.INFO) +fac = function.Declare("Fac", [("n", tf.int32)], [("ret", tf.int32)]) + + +# @function.Defun(tf.int32, func_name="Test", out_names=["ret"]) +# def t(n): +# return tf.constant(1) + + + +# fac = function.Declare("Fac", [("n", tf.int32)], [("ret", tf.int32)]) + +@function.Defun(tf.int32, func_name="Fac", out_names=["ret"]) +def FacImpl(n): + return tf.cond(tf.less_equal(n, 1), + lambda: tf.constant(1), + lambda: n * fac(n - 1)) + + + +FacImpl.add_to_graph(tf.compat.v1.get_default_graph()) + +with tf.compat.v1.Session() as sess: + result = FacImpl(1) + print("Result:", sess.run(result)) + diff --git a/tensorflow/c/c_api.h b/tensorflow/c/c_api.h index 9812b0a7dfcef3..14992dda6b44ca 100644 --- a/tensorflow/c/c_api.h +++ b/tensorflow/c/c_api.h @@ -27,6 +27,7 @@ limitations under the License. #include "tensorflow/c/tf_tensor.h" #include "tensorflow/c/tf_tstring.h" +#include "tensorflow/core/framework/function.pb.h" // -------------------------------------------------------------------------- // C API for TensorFlow. // @@ -860,6 +861,11 @@ TF_CAPI_EXPORT extern void TF_GraphImportGraphDef( TF_Graph* graph, const TF_Buffer* graph_def, const TF_ImportGraphDefOptions* options, TF_Status* status); + + +TF_CAPI_EXPORT extern void TF_GraphAddFunctionDef(TF_Graph* g, const void* proto, size_t proto_len, TF_Status* status); + + // Adds a copy of function `func` and optionally its gradient function `grad` // to `g`. Once `func`/`grad` is added to `g`, it can be called by creating // an operation using the function's name. diff --git a/tensorflow/c/c_api_function.cc b/tensorflow/c/c_api_function.cc index 25805954eff67c..a11cdf28b33964 100644 --- a/tensorflow/c/c_api_function.cc +++ b/tensorflow/c/c_api_function.cc @@ -252,6 +252,24 @@ const char* TF_FunctionName(TF_Function* func) { return func->record->fdef().signature().name().c_str(); } + +void TF_GraphAddFunctionDef(TF_Graph* g, const void* proto, size_t proto_len, TF_Status* status){ + + tensorflow::mutex_lock l(g->mu); + tensorflow::FunctionDef fdef; + bool success = fdef.ParseFromArray(proto, proto_len); + if (!success) { + status->status = InvalidArgument( + "Invalid FunctionDef given to TF_GraphAddFunctionDef"); + return; + } + + + tensorflow::StackTracesMap stack_traces; + status->status = g->graph.AddFunctionDef(fdef,std::move(stack_traces)); +} + + void TF_GraphCopyFunction(TF_Graph* g, const TF_Function* func, const TF_Function* grad, TF_Status* status) { if (func == nullptr) { diff --git a/tensorflow/python/client/_pywrap_tf_session.pyi b/tensorflow/python/client/_pywrap_tf_session.pyi index 14645b34c5f5be..ba00dd680250da 100644 --- a/tensorflow/python/client/_pywrap_tf_session.pyi +++ b/tensorflow/python/client/_pywrap_tf_session.pyi @@ -370,6 +370,7 @@ def TF_GetOpList(arg0: TF_Library) -> object: ... def TF_GetRegisteredKernelsForOp(arg0: str) -> TF_Buffer: ... def TF_GetXlaAutoJitEnabled() -> int: ... def TF_GetXlaConstantFoldingDisabled() -> int: ... +def TF_GraphAddFunctionDef(arg0: PyGraph, arg1: bytes) -> None: ... def TF_GraphCopyFunction(arg0: PyGraph, arg1: TF_Function, arg2: TF_Function) -> None: ... def TF_GraphImportGraphDefWithResults(arg0: PyGraph, arg1: TF_Buffer, arg2: TF_ImportGraphDefOptions) -> TF_ImportGraphDefResults: ... 
def TF_GraphImportGraphDefWithResultsNoSerialization(arg0: PyGraph, arg1, arg2: TF_ImportGraphDefOptions) -> TF_ImportGraphDefResults: ... diff --git a/tensorflow/python/client/tf_session_wrapper.cc b/tensorflow/python/client/tf_session_wrapper.cc index b2d3492f99dfd5..a2ef54ec9d6ec6 100644 --- a/tensorflow/python/client/tf_session_wrapper.cc +++ b/tensorflow/python/client/tf_session_wrapper.cc @@ -1867,6 +1867,21 @@ PYBIND11_MODULE(_pywrap_tf_session, m) { tensorflow::MaybeRaiseRegisteredFromTFStatusWithGIL(status.get()); }); + m.def("TF_GraphAddFunctionDef", + [](PyGraph* graph, py::bytes proto) { + tensorflow::Safe_TF_StatusPtr status = + tensorflow::make_safe(TF_NewStatus()); + tensorflow::Safe_TF_BufferPtr buf = + tensorflow::make_safe(ProtoStringToTFBuffer(proto.ptr())); + + // Release GIL. + py::gil_scoped_release release; + TF_GraphAddFunctionDef(graph->tf_graph(), buf.get()->data, buf.get()->length, status.get()); + tensorflow::MaybeRaiseRegisteredFromTFStatusWithGIL(status.get()); + }); + + + m.def("TF_GraphCopyFunction", [](PyGraph* graph, const TF_Function* func, const TF_Function* grad) { tensorflow::Safe_TF_StatusPtr status = diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py index 34b1eed754bbed..880f802be5862d 100644 --- a/tensorflow/python/framework/ops.py +++ b/tensorflow/python/framework/ops.py @@ -2506,6 +2506,16 @@ def _add_function_recursive(self, function, overwrite=False) -> None: else: self._add_function(f) + def _declare_function_from_op_def(self, op_def) -> None: + + function_def = function_pb2.FunctionDef() + function_def.signature.CopyFrom(op_def) + + with self._c_graph.get() as c_graph: + pywrap_tf_session.TF_GraphAddFunctionDef(c_graph,function_def.SerializeToString()) + + + def _add_function(self, function) -> None: """Adds a function to the graph. @@ -2676,6 +2686,11 @@ def _create_op_internal( input_ops = set(t.op for t in inputs) control_inputs = self._control_dependencies_for_inputs(input_ops) + + if op_def: + self._declare_function_from_op_def(op_def) + + # _create_op_helper mutates the new Operation. `_mutation_lock` ensures a # Session.run call cannot occur between creating and mutating the op. 
with self._mutation_lock(): From 4992e98d52dcb0cb2b909a0fb497dbda2942774f Mon Sep 17 00:00:00 2001 From: George Vasilakopoulos Date: Sun, 7 Apr 2024 20:31:14 +0000 Subject: [PATCH 09/53] forward declaration changes --- TESTS/test2.py | 24 +++++++++++++++---- tensorflow/core/framework/node_def_builder.cc | 6 ++--- tensorflow/core/graph/node_builder.cc | 11 ++++----- tensorflow/python/framework/ops.py | 4 +++- 4 files changed, 28 insertions(+), 17 deletions(-) diff --git a/TESTS/test2.py b/TESTS/test2.py index fff5ad068bca0a..bb653b4880cadc 100644 --- a/TESTS/test2.py +++ b/TESTS/test2.py @@ -1,8 +1,9 @@ -# import os +import os # os.environ['PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION'] = 'python' - +os.environ['TF_CPP_MAX_VLOG_LEVEL'] = '2' +os.environ['TF_DUMP_GRAPH_NAME_FILTER'] = 'Fac' import tensorflow as tf from tensorflow.python.framework import function @@ -14,9 +15,9 @@ fac = function.Declare("Fac", [("n", tf.int32)], [("ret", tf.int32)]) -# @function.Defun(tf.int32, func_name="Test", out_names=["ret"]) -# def t(n): -# return tf.constant(1) +@function.Defun(tf.int32, func_name="Test", out_names=["ret"]) +def t(n): + return tf.constant(1) @@ -29,8 +30,21 @@ def FacImpl(n): lambda: n * fac(n - 1)) +# @function.Defun(tf.int32, func_name="Fac", out_names=["ret"]) +# def FacImpl2(n): +# return t(1) + FacImpl.add_to_graph(tf.compat.v1.get_default_graph()) +# t.add_to_graph(tf.compat.v1.get_default_graph()) +# FacImpl2.add_to_graph(tf.compat.v1.get_default_graph()) + + +print(tf.compat.v1.get_default_graph().as_graph_def()) + + +# writer = tf.compat.v1.summary.FileWriter('/tensorflow/TESTS/graph', tf.compat.v1.get_default_graph()) +# writer = tf.summary.FileWriter('./graphs', tf.compat.v1.get_default_graph()) with tf.compat.v1.Session() as sess: result = FacImpl(1) diff --git a/tensorflow/core/framework/node_def_builder.cc b/tensorflow/core/framework/node_def_builder.cc index 388591a7d6f18f..fcf73e6970bb5c 100644 --- a/tensorflow/core/framework/node_def_builder.cc +++ b/tensorflow/core/framework/node_def_builder.cc @@ -45,10 +45,8 @@ NodeDefBuilder::NodeDefBuilder(StringPiece name, StringPiece op_name, if (status.ok()) { Initialize(); } else { - // errors_.push_back(std::string(status.message())); - // inputs_specified_ = 0; + errors_.push_back(std::string(status.message())); inputs_specified_ = 0; - node_def_.set_op(string(op_name)); } if (debug != nullptr) MergeDebugInfo(*debug, &node_def_); } @@ -262,7 +260,7 @@ Status NodeDefBuilder::Finalize(NodeDef* node_def, bool consume) { } // Add default values for unspecified attrs. 
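+  // With the constructor change above, a failed op lookup is again recorded
+  // in errors_, and Finalize() is expected to report that error before
+  // reaching this point, so op_def_ should be non-null here and the null
+  // check is dropped below.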
- if(op_def_ != nullptr) AddDefaultsToNodeDef(*op_def_, node_def); + AddDefaultsToNodeDef(*op_def_, node_def); return OkStatus(); } diff --git a/tensorflow/core/graph/node_builder.cc b/tensorflow/core/graph/node_builder.cc index ca32cad845ddc2..526e52b009a240 100644 --- a/tensorflow/core/graph/node_builder.cc +++ b/tensorflow/core/graph/node_builder.cc @@ -133,13 +133,10 @@ Status NodeBuilder::Finalize(Graph* graph, Node** created_node, bool consume) { } NodeDef node_def; - TF_RETURN_IF_ERROR(def_builder_.Finalize(&node_def, consume)); - - if(&def_builder_.op_def() != nullptr){ - TF_RETURN_IF_ERROR(ValidateNodeDef(node_def, def_builder_.op_def())); - TF_RETURN_IF_ERROR( - CheckOpDeprecation(def_builder_.op_def(), graph->versions().producer())); - } + TF_RETURN_IF_ERROR(def_builder_.Finalize(&node_def, consume)); + TF_RETURN_IF_ERROR(ValidateNodeDef(node_def, def_builder_.op_def())); + TF_RETURN_IF_ERROR( + CheckOpDeprecation(def_builder_.op_def(), graph->versions().producer())); TF_ASSIGN_OR_RETURN(Node * node, graph->AddNode(std::move(node_def))); node->set_assigned_device_name(assigned_device_); diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py index 880f802be5862d..b863ee2bd878bb 100644 --- a/tensorflow/python/framework/ops.py +++ b/tensorflow/python/framework/ops.py @@ -2512,8 +2512,10 @@ def _declare_function_from_op_def(self, op_def) -> None: function_def.signature.CopyFrom(op_def) with self._c_graph.get() as c_graph: + try: pywrap_tf_session.TF_GraphAddFunctionDef(c_graph,function_def.SerializeToString()) - + except errors.InvalidArgumentError: + pass def _add_function(self, function) -> None: From cd66df8734346277fa626bf0c9ff70b7ea14dba2 Mon Sep 17 00:00:00 2001 From: George Vasilakopoulos Date: Mon, 8 Apr 2024 14:15:34 +0000 Subject: [PATCH 10/53] Transformation seems to be working. 
Commented out function inlining --- tensorflow/core/framework/function.cc | 12 ++++- .../grappler/optimizers/meta_optimizer.cc | 47 ++++++++++--------- 2 files changed, 35 insertions(+), 24 deletions(-) diff --git a/tensorflow/core/framework/function.cc b/tensorflow/core/framework/function.cc index f3ad99ac3322a8..ee6867cf13efe9 100644 --- a/tensorflow/core/framework/function.cc +++ b/tensorflow/core/framework/function.cc @@ -1437,7 +1437,17 @@ Status FunctionLibraryDefinition::AddFunctionRecord( Status FunctionLibraryDefinition::AddHelper(FunctionRecord* registration, bool* added) { *added = false; - auto iter = records_.find(registration->fdef().signature().name()); + auto iter = records_.find(registration->fdef().signature().name()); + std::ofstream fout("/tensorflow/TESTS/mylogger.txt"); + fout << "Searching for " << registration->fdef().signature().name() << std::endl; + fout << "Size of registry " << records_.size() << std::endl; + for(auto reg : records_){ + fout << "Found op "<< reg.second->fdef().signature().name()<DebugString() << std::endl; + + fout.close(); if (iter != records_.end()) { if (!FunctionDefsEqual(iter->second->fdef(), registration->fdef())) { return errors::InvalidArgument( diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer.cc b/tensorflow/core/grappler/optimizers/meta_optimizer.cc index 5b936537a7510e..88391a2abeb1a5 100644 --- a/tensorflow/core/grappler/optimizers/meta_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/meta_optimizer.cc @@ -108,7 +108,7 @@ int NumIterations(const RewriterConfig& cfg) { bool IsRunOnceOptimizer(const string& name) { return name == "layout" || name == "memory_optimizer" || name == "loop_optimizer" || - absl::StartsWith(name, "auto_mixed_precision"); + absl::StartsWith(name, "auto_mixed_precision") || name == "function_optimizer"; } // Creates a function library stub from a real function library: copy only @@ -214,9 +214,9 @@ std::unique_ptr MetaOptimizer::MakeNewOptimizer( return std::unique_ptr(new ModelPruner()); - // if (LowerControlFlow()) { + if (LowerControlFlow()) { MK_OPT("function_transformation", "function_transformation", new FunctionTransformation()); - // } + } MK_OPT("function", "function_optimization", new FunctionOptimizer(cfg_.function_optimization(), @@ -338,6 +338,14 @@ Status MetaOptimizer::InitializeOptimizers( else optimizers->push_back(std::make_unique()); } + if (BOTH_NOT_OFF(function_transformation)) { + if (USER_IS_EXPERIMENTAL_MLIR(function_transformation) || + USER_IS_EXPERIMENTAL_BOTH(function_transformation)) { + VLOG(2) << "function_transformation is not implemented in TFG yet"; + } else { + optimizers->push_back(std::make_unique()); + } + } if (BOTH_NOT_OFF(function_optimization)) { if (USER_IS_EXPERIMENTAL_MLIR(function_optimization) || USER_IS_EXPERIMENTAL_BOTH(function_optimization)) { @@ -494,14 +502,7 @@ Status MetaOptimizer::InitializeOptimizers( optimizers->push_back( std::make_unique(cfg_.auto_parallel().num_replicas())); } - if (BOTH_NOT_OFF(function_transformation)) { - if (USER_IS_EXPERIMENTAL_MLIR(function_transformation) || - USER_IS_EXPERIMENTAL_BOTH(function_transformation)) { - VLOG(2) << "function_transformation is not implemented in TFG yet"; - } else { - optimizers->push_back(std::make_unique()); - } - } + #ifndef ENABLE_MKL if (BOTH_ARE_ON(scoped_allocator_optimization)) { optimizers->push_back(std::make_unique( @@ -776,12 +777,12 @@ Status MetaOptimizer::OptimizeGraph( Cluster* cluster, GrapplerItem&& item, GraphDef* optimized_graph) { int min_graph_nodes = 
cfg_.min_graph_nodes() == 0 ? kDefaultMinGraphNodes : cfg_.min_graph_nodes(); - if (item.graph.node_size() < min_graph_nodes) { - VLOG(3) << "Skipping optimization, graph has less than " << min_graph_nodes - << " nodes."; - *optimized_graph = item.graph; - return absl::OkStatus(); - } + // if (item.graph.node_size() < min_graph_nodes) { + // VLOG(3) << "Skipping optimization, graph has less than " << min_graph_nodes + // << " nodes."; + // *optimized_graph = item.graph; + // return OkStatus(); + // } tensorflow::metrics::ScopedCounter<2> timings( tensorflow::metrics::GetGraphOptimizationCounter(), @@ -832,12 +833,12 @@ Status MetaOptimizer::OptimizeGraph( for (int iteration = 0; iteration < NumIterations(cfg_); ++iteration) { // Don't bother optimizing further if the graph is already tiny. - if (optimized_graph->node_size() < min_graph_nodes) { - VLOG(3) << "Stopping after iteration " << iteration - << ", graph is tiny (#nodes = " << optimized_graph->node_size() - << " < " << min_graph_nodes << ")"; - break; - } + // if (optimized_graph->node_size() < min_graph_nodes) { + // VLOG(3) << "Stopping after iteration " << iteration + // << ", graph is tiny (#nodes = " << optimized_graph->node_size() + // << " < " << min_graph_nodes << ")"; + // break; + // } VLOG(4) << "Starting optimization iteration " << iteration; if (VLOG_IS_ON(4)) { From d7994fddd1b83c44d73643bddbd66a1a142acdce Mon Sep 17 00:00:00 2001 From: George Vasilakopoulos Date: Tue, 16 Apr 2024 10:33:09 +0000 Subject: [PATCH 11/53] Distributed test files --- TESTS/distributed/d2.py | 9 +++++++++ TESTS/distributed/d3.py | 9 +++++++++ TESTS/distributed/distr.py | 41 ++++++++++++++++++++++++++++++++++++++ 3 files changed, 59 insertions(+) create mode 100644 TESTS/distributed/d2.py create mode 100644 TESTS/distributed/d3.py create mode 100644 TESTS/distributed/distr.py diff --git a/TESTS/distributed/d2.py b/TESTS/distributed/d2.py new file mode 100644 index 00000000000000..416fd0590503af --- /dev/null +++ b/TESTS/distributed/d2.py @@ -0,0 +1,9 @@ +import tensorflow as tf + + +cluster_spec = { + "worker": ["localhost:2222", "localhost:2223"] + } + +server = tf.distribute.Server(cluster_spec, job_name="worker", task_index=0) +server.join() \ No newline at end of file diff --git a/TESTS/distributed/d3.py b/TESTS/distributed/d3.py new file mode 100644 index 00000000000000..dfd46821e117ce --- /dev/null +++ b/TESTS/distributed/d3.py @@ -0,0 +1,9 @@ +import tensorflow as tf + + +cluster_spec = { + "worker": ["localhost:2222", "localhost:2223"] + } + +server = tf.distribute.Server(cluster_spec, job_name="worker", task_index=1) +server.join() \ No newline at end of file diff --git a/TESTS/distributed/distr.py b/TESTS/distributed/distr.py new file mode 100644 index 00000000000000..87f16acd2bb6cb --- /dev/null +++ b/TESTS/distributed/distr.py @@ -0,0 +1,41 @@ +import tensorflow as tf +from tensorflow.python.framework import function + +tf.compat.v1.disable_eager_execution() + +cluster = tf.train.ClusterSpec({"local": ["localhost:2222", "localhost:2223"]}) + +fib = function.Declare("Fib", [("n", tf.int32)], [("ret", tf.int32)]) + +@function.Defun(tf.int32, func_name="Fib", out_names=["ret"]) +def FibImpl(n): + + def f1(): + with tf.device("/job:local/replica:0/task:0/device:CPU:0"): + ret = tf.constant(1) + return ret + def f2(): + with tf.device("/job:local/replica:0/task:0/device:CPU:0"): + fib1 = fib(n-1) + with tf.device("/job:local/replica:0/task:1/device:CPU:0"): + fib2 = fib(n-2) + + return fib1 + fib2 + + return tf.cond(tf.less_equal(n, 1), 
f1, f2) + +FibImpl.add_to_graph(tf.compat.v1.get_default_graph()) + +n = tf.constant(1) +x = fib(n) + +res = tf.add(x, 1) + +#print(tf.get_default_graph().as_graph_def()) + +# writer = tf.compat.v1.summary.FileWriter('./graphs', tf.compat.v1.get_default_graph()) + +with tf.compat.v1.Session("grpc://localhost:2222") as sess: + print(sess.run(res)) + +# writer.close() From d60de572fad770fe57b2b7070a5ddb5c4eba2e91 Mon Sep 17 00:00:00 2001 From: George Vasilakopoulos Date: Thu, 18 Apr 2024 13:50:44 +0000 Subject: [PATCH 12/53] Distributed Execution: code migration --- TESTS/distributed/d2.py | 4 +- TESTS/distributed/d3.py | 11 +- TESTS/distributed/distr.py | 31 +- TESTS/test2.py | 4 +- tensorflow/core/BUILD | 2 +- .../distributed_runtime/master_session.cc | 34 + tensorflow/core/graph/graph_partition.cc | 690 +++++++++++++++++- 7 files changed, 758 insertions(+), 18 deletions(-) diff --git a/TESTS/distributed/d2.py b/TESTS/distributed/d2.py index 416fd0590503af..daa16012ea4bca 100644 --- a/TESTS/distributed/d2.py +++ b/TESTS/distributed/d2.py @@ -2,8 +2,8 @@ cluster_spec = { - "worker": ["localhost:2222", "localhost:2223"] + "local": ["172.19.0.3:2222", "172.19.0.2:2223"] } -server = tf.distribute.Server(cluster_spec, job_name="worker", task_index=0) +server = tf.distribute.Server(cluster_spec, job_name="local", task_index=0) server.join() \ No newline at end of file diff --git a/TESTS/distributed/d3.py b/TESTS/distributed/d3.py index dfd46821e117ce..6f1bc6ccec6a68 100644 --- a/TESTS/distributed/d3.py +++ b/TESTS/distributed/d3.py @@ -1,9 +1,14 @@ +import os + + +os.environ['TF_CPP_MAX_VLOG_LEVEL'] = '2' + import tensorflow as tf cluster_spec = { - "worker": ["localhost:2222", "localhost:2223"] + "local": ["172.19.0.3:2222", "172.19.0.2:2223"] } -server = tf.distribute.Server(cluster_spec, job_name="worker", task_index=1) -server.join() \ No newline at end of file +server = tf.distribute.Server(cluster_spec, job_name="local", task_index=1) +server.join() diff --git a/TESTS/distributed/distr.py b/TESTS/distributed/distr.py index 87f16acd2bb6cb..13805180109c25 100644 --- a/TESTS/distributed/distr.py +++ b/TESTS/distributed/distr.py @@ -1,3 +1,8 @@ +import os + + +os.environ['TF_CPP_MAX_VLOG_LEVEL'] = '2' + import tensorflow as tf from tensorflow.python.framework import function @@ -7,6 +12,20 @@ fib = function.Declare("Fib", [("n", tf.int32)], [("ret", tf.int32)]) + + + +@function.Defun(tf.int32, func_name="Minus2", out_names=["ret"]) +def M2(n): + return (n-2) + +@function.Defun(tf.int32, func_name="Minus1", out_names=["ret"]) +def M1(n): + return (n-1) + + + + @function.Defun(tf.int32, func_name="Fib", out_names=["ret"]) def FibImpl(n): @@ -16,9 +35,9 @@ def f1(): return ret def f2(): with tf.device("/job:local/replica:0/task:0/device:CPU:0"): - fib1 = fib(n-1) - with tf.device("/job:local/replica:0/task:1/device:CPU:0"): - fib2 = fib(n-2) + fib1 = M1(n) + with tf.device("/job:local/replica:0/task:0/device:CPU:0"): + fib2 = M2(n) return fib1 + fib2 @@ -26,16 +45,14 @@ def f2(): FibImpl.add_to_graph(tf.compat.v1.get_default_graph()) -n = tf.constant(1) +n = tf.constant(11) x = fib(n) -res = tf.add(x, 1) - #print(tf.get_default_graph().as_graph_def()) # writer = tf.compat.v1.summary.FileWriter('./graphs', tf.compat.v1.get_default_graph()) with tf.compat.v1.Session("grpc://localhost:2222") as sess: - print(sess.run(res)) + print(sess.run(x)) # writer.close() diff --git a/TESTS/test2.py b/TESTS/test2.py index bb653b4880cadc..d27c5f4af1f587 100644 --- a/TESTS/test2.py +++ b/TESTS/test2.py @@ -2,8 
+2,8 @@ # os.environ['PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION'] = 'python' -os.environ['TF_CPP_MAX_VLOG_LEVEL'] = '2' -os.environ['TF_DUMP_GRAPH_NAME_FILTER'] = 'Fac' +# os.environ['TF_CPP_MAX_VLOG_LEVEL'] = '2' +# os.environ['TF_DUMP_GRAPH_NAME_FILTER'] = 'Fac' import tensorflow as tf from tensorflow.python.framework import function diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index e2adb15245c183..ef906e41b25579 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -1764,7 +1764,7 @@ alias( tf_cuda_library( name = "graph", srcs = ["//tensorflow/core/graph:graph_srcs"], - hdrs = ["//tensorflow/core/graph:graph_headers"], + hdrs = ["//tensorflow/core/graph:graph_headers","//tensorflow/core/common_runtime:graph_constructor.h"], deps = [ ":framework", ":framework_internal", diff --git a/tensorflow/core/distributed_runtime/master_session.cc b/tensorflow/core/distributed_runtime/master_session.cc index 5593963988d9e5..607d80378a06b1 100644 --- a/tensorflow/core/distributed_runtime/master_session.cc +++ b/tensorflow/core/distributed_runtime/master_session.cc @@ -24,6 +24,9 @@ limitations under the License. #include #include +#include "tensorflow/core/util/event.pb.h" +#include "tensorflow/core/util/events_writer.h" + #include "absl/status/status.h" #include "tensorflow/core/common_runtime/process_util.h" #include "tensorflow/core/common_runtime/profile_handler.h" @@ -356,6 +359,37 @@ Status MasterSession::ReffedClientGraph::RegisterPartitions( popts.flib_def = client_graph->flib_def.get(); Status s = DoBuildPartitions(popts, client_graph.get(), &graph_defs); if (s.ok()) { + + printf("\n\n MASTER PARTITIONS:\n"); + int i=0; + for (const auto& it: graph_defs) { + string dvc = it.first; + const GraphDef* graphDef = &it.second; + printf("\n\nDeviceName :'%s'\n", dvc.c_str()); + printf("Partition GraphDef:\n %s\n", SummarizeGraphDef(*graphDef).c_str()); + + string p = strings::StrCat("Partition", i); i++; + EventsWriter writer(p); + Event event; + event.set_wall_time(1234); + event.set_step(34); + + const size_t proto_size = graphDef->ByteSizeLong(); + void* buf = port::Malloc(proto_size); + if (buf == nullptr) { + return errors::ResourceExhausted( + "Failed to allocate memory to serialize message of type '" , + graphDef->GetTypeName(), "' and size ", proto_size); + } + graphDef->SerializeToArray(buf, proto_size); + const void* bf = buf; + event.set_graph_def(bf, proto_size); + writer.WriteEvent(event); + + } + + + // NOTE(mrry): The pointers in `graph_defs_for_publishing` do not remain // valid after the call to DoRegisterPartitions begins, so // `stats_publisher_` must make a copy if it wants to retain the diff --git a/tensorflow/core/graph/graph_partition.cc b/tensorflow/core/graph/graph_partition.cc index eac0fa367e5577..e899db008e740d 100644 --- a/tensorflow/core/graph/graph_partition.cc +++ b/tensorflow/core/graph/graph_partition.cc @@ -13,8 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/core/graph/graph_partition.h" - #include #include #include @@ -43,6 +41,12 @@ limitations under the License. 
#include "tensorflow/core/util/device_name_utils.h" #include "tensorflow/core/util/dump_graph.h" +#include "tensorflow/core/common_runtime/graph_constructor.h" +#include "tensorflow/core/graph/graph_partition.h" +#include "tensorflow/core/util/event.pb.h" +#include "tensorflow/core/util/events_writer.h" +#include "tensorflow/core/framework/graph_def_util.h" + namespace tensorflow { namespace { @@ -968,6 +972,647 @@ void SetIncarnation(const PartitionOptions& opts, GraphDef* gdef) { } } +/**************************************************************************************************/ + +struct StateMachineNodeInput { + string src; + int index; +}; + +struct StateMachineParent { + Node* parent_node; + int parent_index; +}; + +struct StateMachineNode { + Node* node; + std::vector inputs; +}; + +struct StateMachineGraph { + std::unordered_map nodes; + std::set depends_on; + Node* merge; +}; + +struct StateMachine { + // A map from unique_ids to StateMachineGraphs representing a general dynamic + // state machine that we update every time a function gets called, and helps us + // gradually build the state machines of the partitions + std::unordered_map state_machine_graphs; + // state_machine_parents is the 'spine' of the graph, + // containing only control flow nodes + std::vector state_machine_parents; + + std::unordered_map switches_info; + // + std::unordered_map switchToPred; + + string leader_partition; + + // Maps device names to smaller strings + std::unordered_map device_names_map; + + std::unordered_map*> partitionsToSMG; +}; + +struct FuncInfo { + // A map from to the num of function's arguments + std::unordered_map funcInputs; + // Helps us seperate functions with same frame_name but + // different non recursive call sites + std::unordered_map funcVisitedCounter; + // Εach vector below operates as a barrier, + // we don't call CallingFunction(..) 
before we gather + // all function's arguments/calls first + std::unordered_map*> funcCalls; +}; + +// Adds root nodes into ready_nodes queue and sets ready_inputs appropriately +Status PreprocessGraph(std::unordered_map &ready_inputs, Graph* g, + std::deque &ready_nodes) { + + + std::unordered_map> returning_nodes; + + for (Node* node : g->nodes()) { + + if (node->in_edges().empty()) { + ready_nodes.push_back(node); + } + bool recursion_merge = 0; + if (IsMerge(node)) { + ready_inputs[node] = 0; + for (const Edge* in_edge : node->in_edges()) { + + Node* in = in_edge->src(); + // if (IsNextIteration(*output_map.GetNode(input))) { + // ready_inputs[node]++; + // } + if (IsCall(in)) { + ready_inputs[node]++; + recursion_merge = 1; + } + } + if (recursion_merge) { + ready_inputs[node]--; + recursion_merge = 0; + } + + } else if (IsReturn(node)) { + + for (const Edge* in_edge : node->in_edges()) { + Node* in = in_edge->src(); + + if (!in_edge->IsControlEdge()) { + int call_id; + TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), "call_id", &call_id)); + returning_nodes[in].emplace(call_id); + } + } + ready_inputs[node] = 0; + + } else { + ready_inputs[node] = 0; + } + } + + for (const auto& retnode : returning_nodes) { + if (retnode.second.size() > 1) { + // Detected Cycle + ready_inputs[retnode.first]++; + } + } + + return OkStatus(); +} + +string GetDeviceMappedName(StateMachine &state_machine, string device_name) { + + std::unordered_map& device_map = state_machine.device_names_map; + + auto slot = &device_map[device_name]; + if (*slot == "") + *slot = strings::StrCat("_p", device_map.size() + 1); + return *slot; +} + +bool IsCallSuccessor(Node* node) { + + for (const Edge* in_edge : node->in_edges()) { + Node* src = in_edge->src(); + if (IsCall(src) && !in_edge->IsControlEdge()) + return true; + } + return false; +} + +void DeleteStateMachineGraph(StateMachine& state_machine, string unique_id) { + + StateMachineGraph *smg = state_machine.state_machine_graphs[unique_id]; + + for (auto& it : smg->nodes) + delete it.second; + delete smg; +} + +std::vector* GetOrCreateCalls(int call_id, std::unordered_map*> &funcCalls) { + auto slot = &funcCalls[call_id]; + if (*slot == nullptr) + *slot = new std::vector; + return *slot; +} + +std::set* GetOrCreatePartition(string partition, std::unordered_map*> &partsTpSmg) { + auto slot = &partsTpSmg[partition]; + if (*slot == nullptr) + *slot = new std::set; + return *slot; +} + +// For one if-else construction there are more than one Switch nodes guarding all the inputs +// that are needed inside the branches but live outside of them. 
We need to collect all the Switch +// nodes that correspond to one if-else construction and treat them as one in the state machines +// switches_info: Every switch node maps to the original switch that we "ll take into account +void CollectSwitches(Graph* g, StateMachine& state_machine) { + + std::unordered_map pred_switch; + + for (Node *node : g->nodes()) { + + if (IsSwitch(node)) { + + for (const Edge *in_edge : node->in_edges()) { + + int port = in_edge->dst_input(); + + // A sloppy way to determine if this is the predicate input + if (!in_edge->IsControlEdge() && port == 1) { + + Node *predicate = in_edge->src(); + + while (IsIdentity(predicate)) { + for (const Edge *inEdge : predicate->in_edges()) { + if (!inEdge->IsControlEdge()) { + predicate = inEdge->src(); + break; + } + } + } + + // We 've got the real predicate + Node *switchNode; + if (pred_switch.find(predicate) == pred_switch.end()) { + // Original switch + pred_switch[predicate] = node; + state_machine.switchToPred[node] = predicate; + switchNode = node; + } else { + // "Synonym" switch + switchNode = pred_switch[predicate]; + } + + state_machine.switches_info[node] = switchNode; + + break; + } + } + printf("Switch : %s -> %s\n", node->name().c_str(), state_machine.switches_info[node]->name().c_str()); + } + } + + printf("\n\n\n"); +} + +void GatherPartitionStateMachines(StateMachine& state_machine, std::set* smgs) { + + std::deque queue; + + for (auto& it : *smgs) + queue.push_back(it); + + while (!queue.empty()) { + string smg = queue.front(); + queue.pop_front(); + + StateMachineGraph* sm_graph = state_machine.state_machine_graphs[smg]; + for (auto& it : sm_graph->depends_on) { + // If not already visited + if (smgs->find(it) == smgs->end()) { + smgs->emplace(it); + queue.push_back(it); + } + } + } +} + +NodeDef* FindNodeInGraphDef(GraphDef& graphDef, string node_name) { + + for (NodeDef& nodeDef : *graphDef.mutable_node()) { + if (nodeDef.name() == node_name) + return &nodeDef; + } + return nullptr; +} + +void ConnectMergeToNode(GraphDef& graphDef, string merge_name, string node_name, + StateMachine& state_machine, string partition_name) { + + // We can safely infer the correct Merge's name and add it as control input to the node + // even though partition state machine's Merge has not already been added into graphdef + string suffix; + (partition_name != state_machine.leader_partition) ? 
+ (suffix = GetDeviceMappedName(state_machine, partition_name)) : (suffix = ""); + + //Add as control input + NodeDef* node = FindNodeInGraphDef(graphDef, node_name); + *node->add_input() = strings::StrCat("^", merge_name, suffix); +} + +void AddPartitionStateMachine(StateMachine& state_machine, GraphDef& main_graphDef, + string unique_id, string partition) { + + StateMachineGraph *sm_graph = state_machine.state_machine_graphs[unique_id]; + string suffix = GetDeviceMappedName(state_machine, partition); + for (const auto &it : sm_graph->nodes) { + string node_name = it.first; + StateMachineNode *sm_node = it.second; + Node *node = sm_node->node; + + // Build NodeDef + NodeDef *nodedef = main_graphDef.add_node(); + //Note: suffix does not guarantee that name is unique + nodedef->set_name(strings::StrCat(node_name, suffix)); + nodedef->set_op(node->op_def().name()); + nodedef->set_device(partition); + + // Add Inputs + for (int i = 0; i < sm_node->inputs.size(); ++i) { + // There won't exist any control inputs here + nodedef->add_input(strings::StrCat(sm_node->inputs[i].src, suffix, ":", sm_node->inputs[i].index)); + + if (absl::StartsWith(StringPiece(sm_node->inputs[i].src),"Dummy_")) { + Tensor tensor(DT_INT32, TensorShape({0})); + NodeDef* dummy = main_graphDef.add_node(); + dummy->set_name(strings::StrCat(sm_node->inputs[i].src, suffix)); + dummy->set_op("Const"); + dummy->set_device(partition); + AddNodeAttr("dtype", DT_INT32, dummy); + AddNodeAttr("value", tensor, dummy); + } + } + + if (IsSwitch(node)) { + // Add predicate input too + nodedef->add_input(state_machine.switchToPred[node]->name()); + // Add control input from partition's Merge to partition's Switch + nodedef->add_input(strings::StrCat("^", sm_graph->merge->name(), suffix)); + } + + for (const auto &itt : node->def().attr()) { + // Not sure if this is copying attrs correctly + if (itt.first == "T") { + // We don't care about keeping the original "T" attr + // in state machine nodes + AddNodeAttr(itt.first, DT_INT32, nodedef); + } else + AddNodeAttr(itt.first, itt.second, nodedef); + } + } +} + +Status AddNodeToStateMachine(StateMachine& state_machine, string unique_id, Node* node, bool cycle) { + + StateMachineGraph *smg = state_machine.state_machine_graphs[unique_id]; + StateMachineNode *smn = new StateMachineNode; + + smn->node = node; + + StateMachineParent *parent = &state_machine.state_machine_parents[node->id()]; + + if (parent->parent_node == nullptr) { + int call_id; + TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), "call_id", &call_id)); + smn->inputs.push_back({strings::StrCat("Dummy_", call_id), 0}); + } else + smn->inputs.push_back({parent->parent_node->name(), parent->parent_index}); + + smg->nodes[node->name()] = smn; + + // If cycle is true, node is a recursive call, that needs to be added as + // input to the corresponding Merge node + if (cycle) { + // We traverse graph the way topological sort does, so we will never + // meet a recursive call node before its corresponding Merge + StateMachineNode* merge = smg->nodes[smg->merge->name()]; + merge->inputs.push_back({node->name(), 0}); + } + + return OkStatus(); +} + +Status CallingFunction(Graph* graph, GraphDef& main_graphDef, StateMachine& state_machine, FuncInfo& funcInfo, + string function_frame_name, int function_call_id, + std::unordered_map& ready_inputs, + std::deque& prev_ready_nodes) { + + Node *merge, *call; + std::deque ready_nodes; + + string function_unique_id = strings::StrCat(function_frame_name, ":", + 
funcInfo.funcVisitedCounter[function_frame_name]); + + std::vector* calls = funcInfo.funcCalls[function_call_id]; + for (int i=0; i < calls->size(); ++i) { + ready_nodes.push_back((*calls)[i]); + } + call = (*calls)[0]; + + // We add only one Call node for all possible function's args in the state machine + TF_RETURN_IF_ERROR(AddNodeToStateMachine(state_machine, function_unique_id, call, false)); + + std::vector& state_machine_parents = state_machine.state_machine_parents; + StateMachineGraph* sm_graph = state_machine.state_machine_graphs[function_unique_id]; + + // Call's successor (the non control output) will be either + // a Merge node (in case of recursion) or an Identity node. + // Either way we add that successor to the state machine, too. + // Same as above, we add only one Merge node instead of one per function's arg + for (const Edge* out_edge : call->out_edges()) { + if (!out_edge->IsControlEdge()) { + merge = out_edge->dst(); + state_machine_parents[merge->id()].parent_node = call; + state_machine_parents[merge->id()].parent_index = 0; + TF_RETURN_IF_ERROR(AddNodeToStateMachine(state_machine, function_unique_id, merge, false)); + sm_graph->merge = merge; + break; + } + } + + while (!ready_nodes.empty()) { + + Node* ready_node = ready_nodes.front(); + ready_nodes.pop_front(); + + int parent_index = 0; + Node* parent = state_machine_parents[ready_node->id()].parent_node; + + // The ops below need to update the parent + if (IsCall(ready_node)) { + parent = call; + } else if (IsCallSuccessor(ready_node)) { + parent = merge; + } else if (IsSwitch(ready_node)) { + Node *sw = state_machine.switches_info[ready_node]; + if (sw == ready_node) + TF_RETURN_IF_ERROR(AddNodeToStateMachine(state_machine, function_unique_id, ready_node, false)); + parent = sw; + } else if (IsMerge(ready_node)) { + // Control Flow (regular) Merge has a corresponding Switch node + // Parent gets the value of that switch node's parent + parent = state_machine_parents[parent->id()].parent_node; + parent_index = state_machine_parents[parent->id()].parent_index; + } else if (IsReturn(ready_node)) { + // Return needs to propagate its corresponding Call's parent to all its successors + for (const Edge* in_edge : ready_node->in_edges()) { + if (in_edge->IsControlEdge()) { + Node* call_node = in_edge->src(); + parent = state_machine_parents[call_node->id()].parent_node; + parent_index = state_machine_parents[call_node->id()].parent_index; + break; + } + } + int call_id; + TF_RETURN_IF_ERROR(GetNodeAttr(ready_node->attrs(), "call_id", &call_id)); + // If not a 'recursive' return + if (call_id == function_call_id) { + // Add the successors of Return node to prev_ready_nodes queue + prev_ready_nodes.push_back(ready_node); + // Set the parent value of the only actual output of return + for (const Edge* out_edge : ready_node->out_edges()) { + Node* out = out_edge->dst(); + state_machine_parents[out->id()].parent_node = parent; + state_machine_parents[out->id()].parent_index = parent_index; + break; + } + continue; + } + } + + // Process ready_node's outputs + for (const Edge* out_edge : ready_node->out_edges()) { + Node* out = out_edge->dst(); + + ready_inputs[out]++; + + // For a cross-device edge, on the dst device, add a control edge + // from the merge node of the state machine to dst. If a send/recv is + // introduced for this edge in future partitioning, we delete this + // control edge and add a new control edge from the merge to the recv. 
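+      // ("merge" here is the Call successor that was recorded above as this
+      // frame's state-machine Merge (sm_graph->merge); ConnectMergeToNode
+      // adds it as a "^"-prefixed control input on the destination node, so
+      // the edge only constrains ordering and carries no tensor data.)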
+ const string& src_device = ready_node->assigned_device_name(); + const string& dst_device = out->assigned_device_name(); + if (src_device != dst_device) { + if (IsCallSuccessor(ready_node) && IsConstant(out)) { + // Remove this control edge that ensures constant executes in the same frame, + // and add a new one from the Constant's partition's state machine merge to the constant + NodeDef* con_node = FindNodeInGraphDef(main_graphDef, out->name()); + for (string& input : *con_node->mutable_input()) { + if (absl::StartsWith(StringPiece(input),strings::StrCat("^", ready_node->name()))) { + string suffix = GetDeviceMappedName(state_machine, dst_device); + input = strings::StrCat("^", merge->name(), suffix); + break; + } + } + } else + ConnectMergeToNode(main_graphDef, merge->name(), out->name(), state_machine, dst_device); + } + + if (ready_inputs[out] == out->in_edges().size()) { + + if (IsSwitch(ready_node)) { + // We need to fix parent_index appropriately + parent_index = out_edge->src_output(); + } + + // Set node's parent + state_machine_parents[out->id()].parent_node = parent; + state_machine_parents[out->id()].parent_index = parent_index; + + std::unordered_map& sm_graphs = state_machine.state_machine_graphs; + + if (IsCall(out)) { + + string frame_name; + TF_RETURN_IF_ERROR(GetNodeAttr(out->attrs(), "frame_name", &frame_name)); + int call_id; + TF_RETURN_IF_ERROR(GetNodeAttr(out->attrs(), "call_id", &call_id)); + + std::vector* calls = GetOrCreateCalls(call_id, funcInfo.funcCalls); + calls->push_back(out); + + if (funcInfo.funcInputs[frame_name] == calls->size()) { + + // We gathered all function's inputs + + string unique_id = strings::StrCat(frame_name, ":", funcInfo.funcVisitedCounter[frame_name]); + + if (sm_graphs.find(unique_id) == sm_graphs.end()) { + + sm_graphs.emplace(unique_id, new StateMachineGraph); + TF_RETURN_IF_ERROR(CallingFunction(graph, main_graphDef, state_machine, funcInfo, frame_name, call_id, ready_inputs, ready_nodes)); + funcInfo.funcVisitedCounter[frame_name]++; + } else { + // Recursive Call (either to the same function or another one (mutual recursion) + TF_RETURN_IF_ERROR(AddNodeToStateMachine(state_machine, unique_id, (*calls)[0], true)); + // Add the recursive call nodes to ready_nodes + for (int i=0; i < calls->size(); ++i) + ready_nodes.push_back((*calls)[i]); + } + + sm_graphs[unique_id]->depends_on.emplace(function_unique_id); + } + } else { + GetOrCreatePartition(dst_device, state_machine.partitionsToSMG)->emplace(function_unique_id); + ready_nodes.push_back(out); + } + } + } + } + + return OkStatus(); +} + +Status AddFunctionStateMachines(const PartitionOptions& opts, + Graph* g, GraphDef& main_graphDef, GraphInfo* g_info) { + + Status status; + GraphDefBuilder::Options bopts(g, &status); + + FuncInfo funcInfo; + int nodes_num = g->num_node_ids(); + + const FunctionDefLibrary& fdef = opts.flib_def->ToProto(); + for (const FunctionDef& func : fdef.function()) { + + int num_inputs = func.signature().input_arg_size(); + string name = func.signature().name(); + funcInfo.funcInputs[name] = num_inputs; + funcInfo.funcVisitedCounter[name] = 0; + } + + StateMachine state_machine; + state_machine.state_machine_parents.resize(nodes_num); + + CollectSwitches(g, state_machine); + + // Add all state machines for cross-device frames. + // A state machine is added only when there is a cross-device edge in a + // non-root frame. 
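+  // Rough shape of the pass, as implemented below:
+  //  1. PreprocessGraph seeds ready_nodes with the graph's root nodes and
+  //     primes the ready-input counters so recursive Call/Merge/Return
+  //     cycles do not stall the traversal.
+  //  2. Nodes are visited in topological order; once every argument of a
+  //     Call with frame_name F has arrived, CallingFunction builds F's
+  //     state machine and records which partitions depend on it.
+  //  3. After the outermost call site is processed, AddPartitionStateMachine
+  //     replicates every needed state machine on each non-leader partition,
+  //     so the control inputs added during the traversal resolve to a local
+  //     copy of that partition's Merge.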
+ + // Visit nodes the way topological sort does + std::deque ready_nodes; + std::unordered_map ready_inputs; + + TF_RETURN_IF_ERROR(PreprocessGraph(ready_inputs, g, ready_nodes)); + + // We convert graph to its equivalent graph_def, cause it's easier + // to extend it with the GraphDef state machines of partitions + g->ToGraphDef(&main_graphDef); + + while (!ready_nodes.empty()) { + Node* ready_node = ready_nodes.front(); + ready_nodes.pop_front(); + + for (const Edge* out_edge : ready_node->out_edges()) { + Node* out = out_edge->dst(); + + ready_inputs[out]++; + + if (ready_inputs[out] == out->in_edges().size()) { + + if (IsCall(out)) { + string frame_name; + TF_RETURN_IF_ERROR(GetNodeAttr(out->attrs(), "frame_name", &frame_name)); + int call_id; + TF_RETURN_IF_ERROR(GetNodeAttr(out->attrs(), "call_id", &call_id)); + + std::vector* calls = GetOrCreateCalls(call_id, funcInfo.funcCalls); + calls->push_back(out); + + if (funcInfo.funcInputs[frame_name] == calls->size()) { + + string unique_id = strings::StrCat(frame_name, ":", funcInfo.funcVisitedCounter[frame_name]); + + // We gathered all function's inputs + state_machine.leader_partition = out->assigned_device_name(); + state_machine.state_machine_graphs.emplace(unique_id, new StateMachineGraph); + TF_RETURN_IF_ERROR(CallingFunction(g, main_graphDef, state_machine, funcInfo, frame_name, call_id, ready_inputs, ready_nodes)); + funcInfo.funcVisitedCounter[frame_name]++; + + // Adding partition state machines to graph + for (auto& it: state_machine.partitionsToSMG) { + string partition = it.first; + + // Leader Partition already has its state machine + if (partition == state_machine.leader_partition) + continue; + + std::set* smgs = it.second; + + // Collect all the state machine graphs that smgs depened on + GatherPartitionStateMachines(state_machine, smgs); + + for (auto& it : *smgs) + AddPartitionStateMachine(state_machine, main_graphDef, it, partition); + } + + // Deallocate space + for (auto& it : state_machine.partitionsToSMG) + delete it.second; + state_machine.partitionsToSMG.clear(); + + for (auto& it: state_machine.state_machine_graphs) + DeleteStateMachineGraph(state_machine, it.first); + state_machine.state_machine_graphs.clear(); + } + } else + ready_nodes.push_back(out); + } + } + } + + // Deallocate space + for (auto& it : funcInfo.funcCalls) + delete it.second; + +/****************************************************************************/ + printf("\n\nSummarize Main Graph\n %s\n", SummarizeGraphDef(main_graphDef).c_str()); + // Write an event, so that we can visualize this optimized graph in tensorboard + EventsWriter writer("Full_Partitioned"); + Event event; + event.set_wall_time(1234); + event.set_step(34); + + const size_t proto_size = main_graphDef.ByteSizeLong(); + void* buf = port::Malloc(proto_size); + if (buf == nullptr) { + return errors::ResourceExhausted( + "Failed to allocate memory to serialize message of type '" , + main_graphDef.GetTypeName(), "' and size ", proto_size); + } + main_graphDef.SerializeToArray(buf, proto_size); + const void* bf = buf; + event.set_graph_def(bf, proto_size); + writer.WriteEvent(event); +/****************************************************************************/ + + return OkStatus(); +} + + + +/**************************************************************************************************/ + Status Partition(const PartitionOptions& opts, Graph* g, std::unordered_map* partitions) { // TODO(b/290689453) Refactor this into smaller functions @@ -977,14 +1622,40 @@ Status 
Partition(const PartitionOptions& opts, Graph* g, partitions->clear(); GraphInfo g_info; + std::unique_ptr new_g(new Graph(OpRegistry::Global())); if (!opts.control_flow_added) { // Add the "code" for distributed execution of control flow. Code is // added only for the frames that are placed on multiple devices. The // new graph is an equivalent transformation of the original graph and // has the property that it can be subsequently partitioned arbitrarily // (down to the level of individual device) for distributed execution. + + GraphDef main_graphDef; + g->ToGraphDef(&main_graphDef); + printf("\n\nSummarize Main Graph:\n %s\n\n", SummarizeGraphDef(main_graphDef).c_str()); + status = AddControlFlow(opts, g, &g_info); if (!status.ok()) return status; + + GraphDef gdef; + status = AddFunctionStateMachines(opts, g, gdef, &g_info); + if (status.ok()) { + // Convert GraphDef back to Graph so it can be partitioned + GraphConstructorOptions gopts; + gopts.allow_internal_ops = true; + TF_RETURN_IF_ERROR( + ConvertGraphDefToGraph(gopts, gdef, new_g.get())); + g = new_g.get(); + + // The graph conversion sets the requested device names but not the assigned + // device names. However, since at this point the graph is placed TF expects + // an assigned device name for every node. Therefore we copy the requested + // device into the assigned device field. + for (Node* node : g->nodes()) { + node->set_assigned_device_name(node->requested_device()); + } + } else return status; + } // At this point, all the graph mutations have been done. Build memory @@ -1058,7 +1729,20 @@ Status Partition(const PartitionOptions& opts, Graph* g, int32_t num_input_edges = 0; for (const Edge* edge : dst->in_edges()) { if (edge->IsControlEdge()) { - if (IsMerge(edge->src()) && IsControlLoop(edge->src())) { + if ((IsMerge(edge->src()) && IsControlLoop(edge->src())) || + (IsCallSuccessor(edge->src()) && (!IsConstant(edge->dst()) || + edge->dst()->in_edges().size() > 1))) { + // Note: not all control edges are control flow edges. + // There are also control edges added in + // FunctionTransformation for ensuring that Constants will execute in the + // correct 'frame'. + // We made sure in AddFunctionsStateMachines that: + // if a Constant in partition A, has such incoming edge from a CallSuccessor(..) + // node, then this node will definitely belong in the same A partition, so we + // can safely add those edges in "inputs" as we do with common control edges. + // All the other edges whose src node is a CallSuccessor node are control flow edges. + + // This is one of the control edges added for control flow. There // can be multiple such edges as the dest node may have multiple // remote inputs. We keep track of the number of such edges. From 4aab4ea70b36be73465a7a27fab4bef50450d9ee Mon Sep 17 00:00:00 2001 From: George Vasilakopoulos Date: Tue, 23 Apr 2024 17:33:40 +0000 Subject: [PATCH 13/53] Instructions on how to built tf in a container --- recipe.txt | 155 +++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 155 insertions(+) create mode 100644 recipe.txt diff --git a/recipe.txt b/recipe.txt new file mode 100644 index 00000000000000..7f581ee9fcde91 --- /dev/null +++ b/recipe.txt @@ -0,0 +1,155 @@ +========================================= IN A NEW NODE ================================ + +1. Clone your repo + +2. Install Docker + +4. Install VS code extensions etc + +5. 
Create Docker image using this Dockerfile + +########### Dockerfile for my own project ################ + +FROM tensorflow/tensorflow:latest + +RUN rm -rf /tensorflow +COPY ./tensorflow /tensorflow + +RUN apt-get update && apt-get install clang -y \ + && apt-get install -y gdb \ + && apt-get install -y patchelf + && apt-get install -y git + +RUN echo "deb [arch=amd64] https://storage.googleapis.com/bazel-apt stable jdk1.8" | tee /etc/apt/sources.list.d/bazel.list > /dev/null + +RUN curl https://bazel.build/bazel-release.pub.gpg | apt-key add - +RUN apt update +RUN apt install -y bazel-6.5.0 +RUN apt install -y bazel +WORKDIR /tensorflow + +RUN bazel build --config=dbg //tensorflow/tools/pip_package:build_pip_package +# RUN ./bazel-bin/tensorflow/tools/pip_package/build_pip_package /mnt + +############################################################# + +6. Create a container like this + + +###### +docker run -d --restart always -it --cap-add=SYS_PTRACE --security-opt seccomp=unconfined -w /tensorflow -v $PWD:/mnt \ + -e HOST_PERMS="\\((id -u):\\)(id -g)" my_image bash +###### + +======================================================================================= + +CREATING KELLY'S IMAGE + +1. Clone her repo + +2. Build docker image using the following Dockerfile: + + +# docker build -t kelly_image . + +############# Dockerfile for Kelly's project ################ + +FROM tensorflow/tensorflow:1.4.0 + +RUN rm -rf /tensorflow +COPY ./tensorflow /tensorflow + + + +RUN apt-get update \ + && apt-get install -y curl wget \ + && apt-get install -y software-properties-common \ + && apt-get install -y unzip \ + && apt-get install -y git \ + && apt-get install -y gcc g++ \ + && apt-get install -y gdb +########################################################## + + +4. RUN + +# apt-get upgrade + + +5. Install Conda and create virtual environment: + +# cd +# wget https://repo.anaconda.com/archive/Anaconda3-2022.05-Linux-x86_64.sh +# chmod +x Anaconda3-2022.05-Linux-x86_64.sh +# ./Anaconda3-2022.05-Linux-x86_64.sh + +... + +# conda create -n venv pip python=3.7 +# conda activate venv + +6. Install some stuff: + +pip install -U --user pip six numpy wheel setuptools mock future>=0.17.1 +pip install -U --user keras_applications==1.0.6 --no-deps +pip install -U --user keras_preprocessing==1.0.5 --no-deps + + +7. Install bazel: + +# cd +# wget https://raw.githubusercontent.com/acharal/tensorflow/recursive-functions/tensorflow/tools/ci_build/install/install_bazel.sh + +# chmod +x install_bazel.sh +# ./install_bazel.sh + + +8. In tensorflow/workspace.bzl change the installation of cython to + + + +############## +native.new_http_archive( + name = "cython", + sha256 = "94916d1ede67682638d3cc0feb10648ff14dc51fb7a7f147f4fedce78eaaea97", + urls = [ + "https://files.pythonhosted.org/packages/f0/66/6309291b19b498b672817bd237caec787d1b18013ee659f17b1ec5844887/Cython-0.29.tar.gz", + ], + strip_prefix = "Cython-0.29", + build_file = str(Label("//third_party:cython.BUILD")), + ) +############## + + +9. 
Build Tensorflow as follows: + + +# ./configure +# bazel build --config=opt //tensorflow/tools/pip_package:build_pip_package --cxxopt="-g" --cxxopt="-D_GLIBCXX_USE_CXX11_ABI=0" --cxxopt="-fpermissive" +# bazel-bin/tensorflow/tools/pip_package/build_pip_package /tmp/tensorflow_pkg/ +# pip3 uninstall -y tensorflow +# pip3 install /tmp/tensorflow_pkg/tensorflow-1.4.2-cp37-cp37m-linux_x86_64.whl + + + + + + +Comments: + +// RUN apt-get install -y software-properties-common +// RUN apt-get install unzip +// RUN apt-get update +// RUN add-apt-repository -y ppa:ubuntu-toolchain-r/test +// // RUN ./bazel-bin/tensorflow/tools/pip_package/build_pip_package /mnt +// RUN apt-get install -y gcc-11 g++-11 + +// RUN update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-11 60 --slave /usr/bin/g++ g++ /usr/bin/g++-11 +// RUN wget https://repo.anaconda.com/archive/Anaconda3-2022.05-Linux-x86_64.sh +// RUN chmod +x Anaconda3-2022.05-Linux-x86_64.sh + +// // RUN conda create -n venv pip python=3.7 +// RUN wget https://raw.githubusercontent.com/acharal/tensorflow/recursive-functions/tensorflow/tools/ci_build/install/install_bazel.sh + + + From 98aa2747f400fa710898f8682b633cd908b1aaec Mon Sep 17 00:00:00 2001 From: Calliope Kostopoulou Date: Thu, 25 Apr 2024 07:50:05 +0000 Subject: [PATCH 14/53] minimizing changes --- .../core/revived_types/flat_tensor_function.cc | 2 -- .../core/common_runtime/propagator_state.h | 16 +++++++--------- 2 files changed, 7 insertions(+), 11 deletions(-) diff --git a/tensorflow/c/experimental/saved_model/core/revived_types/flat_tensor_function.cc b/tensorflow/c/experimental/saved_model/core/revived_types/flat_tensor_function.cc index ec8d4ad31d0f44..d6e568090f7d27 100644 --- a/tensorflow/c/experimental/saved_model/core/revived_types/flat_tensor_function.cc +++ b/tensorflow/c/experimental/saved_model/core/revived_types/flat_tensor_function.cc @@ -13,8 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include - #include "tensorflow/c/experimental/saved_model/core/revived_types/flat_tensor_function.h" #include diff --git a/tensorflow/core/common_runtime/propagator_state.h b/tensorflow/core/common_runtime/propagator_state.h index 0b5d2329ac49df..79b3ed9d8e31ae 100644 --- a/tensorflow/core/common_runtime/propagator_state.h +++ b/tensorflow/core/common_runtime/propagator_state.h @@ -283,15 +283,6 @@ class PropagatorState { // The number of outstanding iterations. int num_outstanding_iterations TF_GUARDED_BY(mu) = 1; - // Mapping from frame ID to outstanding frames. A new frame is created - // at some iteration of an active frame. So the unique key for the new - // child frame is a hash composed of the ID of the parent frame, the iteration - // number at which the parent frame is creating the new frame, and the - // name of the new frame from nodedef. - absl::flat_hash_map outstanding_child_frames_ - TF_GUARDED_BY(mu); - - private: // The active iteration states of this frame. gtl::InlinedVector iterations; @@ -549,6 +540,13 @@ class PropagatorState { // The root frame in which the execution of this step is started. FrameState* root_frame_; + // Mapping from frame ID to outstanding frames. A new frame is created + // at some iteration of an active frame. 
So the unique key for the new + // child frame is a hash composed of the ID of the parent frame, the iteration + // number at which the parent frame is creating the new frame, and the + // name of the new frame from nodedef. + absl::flat_hash_map outstanding_frames_ + TF_GUARDED_BY(mu_); PropagatorState(const PropagatorState&) = delete; void operator=(const PropagatorState&) = delete; From 5495dd5231b249874fed6fa605a7ee5f1bff1e81 Mon Sep 17 00:00:00 2001 From: Calliope Kostopoulou Date: Thu, 25 Apr 2024 08:00:55 +0000 Subject: [PATCH 15/53] minimizing changes --- tensorflow/core/framework/function.cc | 18 ------------------ 1 file changed, 18 deletions(-) diff --git a/tensorflow/core/framework/function.cc b/tensorflow/core/framework/function.cc index ee6867cf13efe9..3dc15a6842a065 100644 --- a/tensorflow/core/framework/function.cc +++ b/tensorflow/core/framework/function.cc @@ -16,7 +16,6 @@ limitations under the License. #include "tensorflow/core/framework/function.h" #include -#include #include #include #include @@ -1438,16 +1437,6 @@ Status FunctionLibraryDefinition::AddHelper(FunctionRecord* registration, bool* added) { *added = false; auto iter = records_.find(registration->fdef().signature().name()); - std::ofstream fout("/tensorflow/TESTS/mylogger.txt"); - fout << "Searching for " << registration->fdef().signature().name() << std::endl; - fout << "Size of registry " << records_.size() << std::endl; - for(auto reg : records_){ - fout << "Found op "<< reg.second->fdef().signature().name()<DebugString() << std::endl; - - fout.close(); if (iter != records_.end()) { if (!FunctionDefsEqual(iter->second->fdef(), registration->fdef())) { return errors::InvalidArgument( @@ -1726,13 +1715,6 @@ Status FunctionLibraryDefinition::LookUp( const string& op, const OpRegistrationData** op_reg_data) const { tf_shared_lock l(mu_); auto iter = records_.find(op); - std::ofstream outputFile("/tensorflow/TESTS/mylog.txt", std::ios::app); - outputFile << "Searching for " << op << std::endl; - for(auto s : ListFunctionNames()){ - outputFile << "Function Name: " << s << std::endl; - } - - outputFile.close(); if (iter != records_.end()) { *op_reg_data = &iter->second->op_registration_data(); return OkStatus(); From 09eaaf20eda9bfb08928f3c1e36bdfdff9a4ae78 Mon Sep 17 00:00:00 2001 From: Calliope Kostopoulou Date: Thu, 25 Apr 2024 08:01:43 +0000 Subject: [PATCH 16/53] minimizing changes --- tensorflow/core/framework/function.cc | 1 + 1 file changed, 1 insertion(+) diff --git a/tensorflow/core/framework/function.cc b/tensorflow/core/framework/function.cc index 3dc15a6842a065..28fd4e304d019a 100644 --- a/tensorflow/core/framework/function.cc +++ b/tensorflow/core/framework/function.cc @@ -16,6 +16,7 @@ limitations under the License. 
#include "tensorflow/core/framework/function.h" #include + #include #include #include From eb97c749e7ba39e3cd1c9f518a8d35a8b24bccb9 Mon Sep 17 00:00:00 2001 From: Calliope Kostopoulou Date: Thu, 25 Apr 2024 08:02:49 +0000 Subject: [PATCH 17/53] minimizing changes --- tensorflow/core/framework/function.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/core/framework/function.cc b/tensorflow/core/framework/function.cc index 28fd4e304d019a..0b6bacd94af0d9 100644 --- a/tensorflow/core/framework/function.cc +++ b/tensorflow/core/framework/function.cc @@ -1437,7 +1437,7 @@ Status FunctionLibraryDefinition::AddFunctionRecord( Status FunctionLibraryDefinition::AddHelper(FunctionRecord* registration, bool* added) { *added = false; - auto iter = records_.find(registration->fdef().signature().name()); + auto iter = records_.find(registration->fdef().signature().name()); if (iter != records_.end()) { if (!FunctionDefsEqual(iter->second->fdef(), registration->fdef())) { return errors::InvalidArgument( From f54489210a33d9f6dd830cf4e07924d78cd0a9d6 Mon Sep 17 00:00:00 2001 From: Calliope Kostopoulou Date: Thu, 25 Apr 2024 08:03:35 +0000 Subject: [PATCH 18/53] minimizing changes --- tensorflow/core/framework/op.cc | 2 -- 1 file changed, 2 deletions(-) diff --git a/tensorflow/core/framework/op.cc b/tensorflow/core/framework/op.cc index 6cfec232612747..ccd5edcb3d37b5 100644 --- a/tensorflow/core/framework/op.cc +++ b/tensorflow/core/framework/op.cc @@ -15,8 +15,6 @@ limitations under the License. #include "tensorflow/core/framework/op.h" -#include - #include #include #include From 585a59ec5398001236cd42b24ffe3f30163732b8 Mon Sep 17 00:00:00 2001 From: Calliope Kostopoulou Date: Thu, 25 Apr 2024 08:05:55 +0000 Subject: [PATCH 19/53] minimizing changes --- tensorflow/core/graph/node_builder.cc | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tensorflow/core/graph/node_builder.cc b/tensorflow/core/graph/node_builder.cc index 526e52b009a240..dbd8fafd1ea523 100644 --- a/tensorflow/core/graph/node_builder.cc +++ b/tensorflow/core/graph/node_builder.cc @@ -133,10 +133,11 @@ Status NodeBuilder::Finalize(Graph* graph, Node** created_node, bool consume) { } NodeDef node_def; - TF_RETURN_IF_ERROR(def_builder_.Finalize(&node_def, consume)); + TF_RETURN_IF_ERROR(def_builder_.Finalize(&node_def, consume)); TF_RETURN_IF_ERROR(ValidateNodeDef(node_def, def_builder_.op_def())); TF_RETURN_IF_ERROR( CheckOpDeprecation(def_builder_.op_def(), graph->versions().producer())); + TF_ASSIGN_OR_RETURN(Node * node, graph->AddNode(std::move(node_def))); node->set_assigned_device_name(assigned_device_); From 545d7db34bc24f127c8eb3097b351ab7c91058d1 Mon Sep 17 00:00:00 2001 From: Calliope Kostopoulou Date: Thu, 25 Apr 2024 08:17:44 +0000 Subject: [PATCH 20/53] minimizing changes + additions in build --- tensorflow/core/BUILD | 4 +++- tensorflow/python/autograph/impl/api.py | 6 ------ 2 files changed, 3 insertions(+), 7 deletions(-) diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index ef906e41b25579..840ed526c72b76 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -579,6 +579,7 @@ cc_library( "//tensorflow/core/kernels:fact_op", "//tensorflow/core/kernels:fake_quant_ops", "//tensorflow/core/kernels:filesystem_ops", + "//tensorflow/core/kernels:function_control_ops", "//tensorflow/core/kernels:function_ops", "//tensorflow/core/kernels:functional_ops", "//tensorflow/core/kernels:grappler", @@ -915,6 +916,7 @@ filegroup( "encode_proto_ops_op_lib", 
"experimental_dataset_ops_op_lib", "filesystem_ops_op_lib", + "function_control_ops_op_lib", "function_ops_op_lib", "functional_grad", "functional_ops_op_lib", @@ -1764,7 +1766,7 @@ alias( tf_cuda_library( name = "graph", srcs = ["//tensorflow/core/graph:graph_srcs"], - hdrs = ["//tensorflow/core/graph:graph_headers","//tensorflow/core/common_runtime:graph_constructor.h"], + hdrs = ["//tensorflow/core/graph:graph_headers"], deps = [ ":framework", ":framework_internal", diff --git a/tensorflow/python/autograph/impl/api.py b/tensorflow/python/autograph/impl/api.py index c89e995bd461b8..64ddf316a88737 100644 --- a/tensorflow/python/autograph/impl/api.py +++ b/tensorflow/python/autograph/impl/api.py @@ -321,8 +321,6 @@ def converted_call(f, args, kwargs, caller_fn_scope=None, options=None): logging.log(1, 'Converted call: %s\n args: %s\n kwargs: %s\n', f, args, kwargs) - - if options is None: if caller_fn_scope is None: raise ValueError('either caller_fn_scope or options must have a value') @@ -340,10 +338,6 @@ def converted_call(f, args, kwargs, caller_fn_scope=None, options=None): logging.log(2, 'Permanently allowed: %s: AutoGraph artifact', f) return _call_unconverted(f, args, kwargs, options) - - - - # If this is a partial, unwrap it and redo all the checks. if isinstance(f, functools.partial): new_kwargs = {} From 630bbb18d65ea9977afae25bcbb49eee0fb67399 Mon Sep 17 00:00:00 2001 From: Calliope Kostopoulou Date: Thu, 25 Apr 2024 08:18:35 +0000 Subject: [PATCH 21/53] minimizing changes --- tensorflow/python/autograph/impl/api.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/python/autograph/impl/api.py b/tensorflow/python/autograph/impl/api.py index 64ddf316a88737..b60c457599e717 100644 --- a/tensorflow/python/autograph/impl/api.py +++ b/tensorflow/python/autograph/impl/api.py @@ -320,7 +320,7 @@ def converted_call(f, args, kwargs, caller_fn_scope=None, options=None): """ logging.log(1, 'Converted call: %s\n args: %s\n kwargs: %s\n', f, args, kwargs) - + if options is None: if caller_fn_scope is None: raise ValueError('either caller_fn_scope or options must have a value') From 3dca7bcf2ed7c3f8da0607eada0adebcf0b6408c Mon Sep 17 00:00:00 2001 From: Calliope Kostopoulou Date: Thu, 25 Apr 2024 09:15:47 +0000 Subject: [PATCH 22/53] adding back some changes --- TESTS/comp | 7 ------- recipe.txt | 11 +++++++---- recursion-tests/comp | 5 +++++ {TESTS => recursion-tests}/distributed/d2.py | 0 {TESTS => recursion-tests}/distributed/d3.py | 0 {TESTS => recursion-tests}/distributed/distr.py | 0 {TESTS => recursion-tests}/factorial.py | 0 {TESTS => recursion-tests}/nohup.out | 0 {TESTS => recursion-tests}/test.py | 0 {TESTS => recursion-tests}/test2.py | 0 tensorflow/core/BUILD | 2 +- .../core/common_runtime/propagator_state.h | 16 ++++++++-------- .../core/grappler/optimizers/meta_optimizer.cc | 1 - tensorflow/core/kernels/function_control_ops.cc | 4 ++-- 14 files changed, 23 insertions(+), 23 deletions(-) delete mode 100755 TESTS/comp create mode 100755 recursion-tests/comp rename {TESTS => recursion-tests}/distributed/d2.py (100%) rename {TESTS => recursion-tests}/distributed/d3.py (100%) rename {TESTS => recursion-tests}/distributed/distr.py (100%) rename {TESTS => recursion-tests}/factorial.py (100%) rename {TESTS => recursion-tests}/nohup.out (100%) rename {TESTS => recursion-tests}/test.py (100%) rename {TESTS => recursion-tests}/test2.py (100%) diff --git a/TESTS/comp b/TESTS/comp deleted file mode 100755 index 5e3449fb10e6e9..00000000000000 --- a/TESTS/comp +++ 
/dev/null @@ -1,7 +0,0 @@ -#!/bin/bash - -cd .. -bazel build --disk_cache=~/mycache --config=dbg //tensorflow/tools/pip_package:build_pip_package && -./bazel-bin/tensorflow/tools/pip_package/build_pip_package /mnt && -pip uninstall tensorflow -y && -pip install /mnt/tensorflow-*.whl diff --git a/recipe.txt b/recipe.txt index 7f581ee9fcde91..c7f6bba9f3ac73 100644 --- a/recipe.txt +++ b/recipe.txt @@ -17,7 +17,6 @@ COPY ./tensorflow /tensorflow RUN apt-get update && apt-get install clang -y \ && apt-get install -y gdb \ - && apt-get install -y patchelf && apt-get install -y git RUN echo "deb [arch=amd64] https://storage.googleapis.com/bazel-apt stable jdk1.8" | tee /etc/apt/sources.list.d/bazel.list > /dev/null @@ -26,10 +25,14 @@ RUN curl https://bazel.build/bazel-release.pub.gpg | apt-key add - RUN apt update RUN apt install -y bazel-6.5.0 RUN apt install -y bazel -WORKDIR /tensorflow -RUN bazel build --config=dbg //tensorflow/tools/pip_package:build_pip_package -# RUN ./bazel-bin/tensorflow/tools/pip_package/build_pip_package /mnt +RUN git clone https://github.com/GeorgeVasilakopoulos/tensorflow.git +WORKDIR tensorflow +RUN git checkout recursion + + +bazel build --config=dbg //tensorflow/tools/pip_package:wheel --repo_env=WHEEL_NAME=tensorflow_cpu +pip install bazel-bin/tensorflow/tools/pip_package/wheel_house/tensorflow_cpu-2.17.0-cp311-cp311-linux_x86_64.whl ############################################################# diff --git a/recursion-tests/comp b/recursion-tests/comp new file mode 100755 index 00000000000000..90d6846bfdc8c0 --- /dev/null +++ b/recursion-tests/comp @@ -0,0 +1,5 @@ +#!/bin/bash + +cd .. +bazel build --config=dbg //tensorflow/tools/pip_package:wheel --repo_env=WHEEL_NAME=tensorflow_cpu && +pip install bazel-bin/tensorflow/tools/pip_package/wheel_house/tensorflow_cpu-2.17.0-cp311-cp311-linux_x86_64.whl diff --git a/TESTS/distributed/d2.py b/recursion-tests/distributed/d2.py similarity index 100% rename from TESTS/distributed/d2.py rename to recursion-tests/distributed/d2.py diff --git a/TESTS/distributed/d3.py b/recursion-tests/distributed/d3.py similarity index 100% rename from TESTS/distributed/d3.py rename to recursion-tests/distributed/d3.py diff --git a/TESTS/distributed/distr.py b/recursion-tests/distributed/distr.py similarity index 100% rename from TESTS/distributed/distr.py rename to recursion-tests/distributed/distr.py diff --git a/TESTS/factorial.py b/recursion-tests/factorial.py similarity index 100% rename from TESTS/factorial.py rename to recursion-tests/factorial.py diff --git a/TESTS/nohup.out b/recursion-tests/nohup.out similarity index 100% rename from TESTS/nohup.out rename to recursion-tests/nohup.out diff --git a/TESTS/test.py b/recursion-tests/test.py similarity index 100% rename from TESTS/test.py rename to recursion-tests/test.py diff --git a/TESTS/test2.py b/recursion-tests/test2.py similarity index 100% rename from TESTS/test2.py rename to recursion-tests/test2.py diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index 840ed526c72b76..a6749b3aeb915c 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -1766,7 +1766,7 @@ alias( tf_cuda_library( name = "graph", srcs = ["//tensorflow/core/graph:graph_srcs"], - hdrs = ["//tensorflow/core/graph:graph_headers"], + hdrs = ["//tensorflow/core/graph:graph_headers", "//tensorflow/core/common_runtime:graph_constructor.h"], deps = [ ":framework", ":framework_internal", diff --git a/tensorflow/core/common_runtime/propagator_state.h 
b/tensorflow/core/common_runtime/propagator_state.h index 79b3ed9d8e31ae..0f9099973ba641 100644 --- a/tensorflow/core/common_runtime/propagator_state.h +++ b/tensorflow/core/common_runtime/propagator_state.h @@ -283,6 +283,14 @@ class PropagatorState { // The number of outstanding iterations. int num_outstanding_iterations TF_GUARDED_BY(mu) = 1; + // Mapping from frame ID to outstanding frames. A new frame is created + // at some iteration of an active frame. So the unique key for the new + // child frame is a hash composed of the ID of the parent frame, the iteration + // number at which the parent frame is creating the new frame, and the + // name of the new frame from nodedef. + absl::flat_hash_map outstanding_child_frames_ + TF_GUARDED_BY(mu); + private: // The active iteration states of this frame. gtl::InlinedVector iterations; @@ -540,14 +548,6 @@ class PropagatorState { // The root frame in which the execution of this step is started. FrameState* root_frame_; - // Mapping from frame ID to outstanding frames. A new frame is created - // at some iteration of an active frame. So the unique key for the new - // child frame is a hash composed of the ID of the parent frame, the iteration - // number at which the parent frame is creating the new frame, and the - // name of the new frame from nodedef. - absl::flat_hash_map outstanding_frames_ - TF_GUARDED_BY(mu_); - PropagatorState(const PropagatorState&) = delete; void operator=(const PropagatorState&) = delete; }; diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer.cc b/tensorflow/core/grappler/optimizers/meta_optimizer.cc index 88391a2abeb1a5..bfeee9ecec9115 100644 --- a/tensorflow/core/grappler/optimizers/meta_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/meta_optimizer.cc @@ -268,7 +268,6 @@ std::unique_ptr MetaOptimizer::MakeNewOptimizer( MK_OPT("pin_to_host", "pin_to_host_optimization", new PinToHostOptimizer(cfg_.pin_to_host_optimization())); - return std::unique_ptr(); } diff --git a/tensorflow/core/kernels/function_control_ops.cc b/tensorflow/core/kernels/function_control_ops.cc index a22c0791022613..b49cbc3471456d 100644 --- a/tensorflow/core/kernels/function_control_ops.cc +++ b/tensorflow/core/kernels/function_control_ops.cc @@ -105,8 +105,8 @@ REGISTER_SYCL_HOST_KERNEL(ResourceHandle); REGISTER_GPU_HOST_KERNEL(int32); REGISTER_GPU_HOST_REF_KERNEL(int32); -REGISTER_GPU_HOST_KERNEL(string); -REGISTER_GPU_HOST_REF_KERNEL(string); +REGISTER_GPU_HOST_KERNEL(tstring); +REGISTER_GPU_HOST_REF_KERNEL(tstring); REGISTER_GPU_HOST_KERNEL(ResourceHandle); #undef REGISTER_GPU_HOST_KERNEL From a8588d4d2354f5b7cf996a1414b868cc21e6d658 Mon Sep 17 00:00:00 2001 From: Calliope Kostopoulou Date: Thu, 25 Apr 2024 09:19:24 +0000 Subject: [PATCH 23/53] fix --- tensorflow/core/kernels/function_control_ops.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/core/kernels/function_control_ops.cc b/tensorflow/core/kernels/function_control_ops.cc index b49cbc3471456d..89d5a356427031 100644 --- a/tensorflow/core/kernels/function_control_ops.cc +++ b/tensorflow/core/kernels/function_control_ops.cc @@ -184,7 +184,7 @@ REGISTER_SYCL_HOST_KERNEL(string); ReturnOp) REGISTER_GPU_HOST_KERNEL(int32); -REGISTER_GPU_HOST_KERNEL(string); +REGISTER_GPU_HOST_KERNEL(tstring); #undef REGISTER_GPU_HOST_KERNEL From 4e95e5e700543e3e7d9c355e63391b51a8c5b55c Mon Sep 17 00:00:00 2001 From: Calliope Kostopoulou Date: Thu, 25 Apr 2024 09:47:54 +0000 Subject: [PATCH 24/53] upgrade instructions --- recipe.txt | 1 + 1 file changed, 
1 insertion(+) diff --git a/recipe.txt b/recipe.txt index c7f6bba9f3ac73..30dd6eff463f7c 100644 --- a/recipe.txt +++ b/recipe.txt @@ -30,6 +30,7 @@ RUN git clone https://github.com/GeorgeVasilakopoulos/tensorflow.git WORKDIR tensorflow RUN git checkout recursion +pip install --upgrade pip setuptools wheel bazel build --config=dbg //tensorflow/tools/pip_package:wheel --repo_env=WHEEL_NAME=tensorflow_cpu pip install bazel-bin/tensorflow/tools/pip_package/wheel_house/tensorflow_cpu-2.17.0-cp311-cp311-linux_x86_64.whl From 9b21ab24a04ba0a20bf2733c178963115825b932 Mon Sep 17 00:00:00 2001 From: George Vasilakopoulos Date: Thu, 9 May 2024 10:23:43 +0000 Subject: [PATCH 25/53] function_transformation should run once --- tensorflow/core/grappler/optimizers/meta_optimizer.cc | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer.cc b/tensorflow/core/grappler/optimizers/meta_optimizer.cc index bfeee9ecec9115..21f666d509539d 100644 --- a/tensorflow/core/grappler/optimizers/meta_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/meta_optimizer.cc @@ -108,7 +108,7 @@ int NumIterations(const RewriterConfig& cfg) { bool IsRunOnceOptimizer(const string& name) { return name == "layout" || name == "memory_optimizer" || name == "loop_optimizer" || - absl::StartsWith(name, "auto_mixed_precision") || name == "function_optimizer"; + absl::StartsWith(name, "auto_mixed_precision") || name == "function_transformation"; } // Creates a function library stub from a real function library: copy only @@ -350,9 +350,9 @@ Status MetaOptimizer::InitializeOptimizers( USER_IS_EXPERIMENTAL_BOTH(function_optimization)) { VLOG(2) << "function_optimization is not implemented in TFG yet"; } else { - optimizers->push_back(std::make_unique( - cfg_.function_optimization(), - /*lower_control_flow=*/LowerControlFlow())); + // optimizers->push_back(std::make_unique( + // cfg_.function_optimization(), + // /*lower_control_flow=*/LowerControlFlow())); } } if (BOTH_NOT_OFF(common_subgraph_elimination) && From d64bddf770c591d9224397b9aedbadda665fb866 Mon Sep 17 00:00:00 2001 From: George Vasilakopoulos Date: Thu, 9 May 2024 10:24:45 +0000 Subject: [PATCH 26/53] Skip FunctionDef creation for Return Nodes --- tensorflow/core/grappler/utils/functions.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/core/grappler/utils/functions.cc b/tensorflow/core/grappler/utils/functions.cc index c2d848aaa67ae6..94b7e3cadc39f0 100644 --- a/tensorflow/core/grappler/utils/functions.cc +++ b/tensorflow/core/grappler/utils/functions.cc @@ -608,7 +608,7 @@ Status MakeFunctionDef(const GrapplerFunctionItem& item, // Skip original `_Arg` and `_Retval` nodes. If node was converted to some // other type (e.g. inputs converted to placeholders), we need to check that // it's not registered as function input or output node. 
- if (IsArg(func_node) || IsRetval(func_node) || + if (IsArg(func_node) || IsRetval(func_node) || IsReturn(func_node) || helper.IsInputNode(func_node) || helper.IsOutputNode(func_node)) continue; From e51d338121188bee78561ac7cc6c09f08e3dbbac Mon Sep 17 00:00:00 2001 From: George Vasilakopoulos Date: Thu, 9 May 2024 10:25:14 +0000 Subject: [PATCH 27/53] flib_def could be NULL --- tensorflow/core/graph/graph_partition.cc | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/tensorflow/core/graph/graph_partition.cc b/tensorflow/core/graph/graph_partition.cc index e899db008e740d..9893f4892917e8 100644 --- a/tensorflow/core/graph/graph_partition.cc +++ b/tensorflow/core/graph/graph_partition.cc @@ -1491,7 +1491,16 @@ Status AddFunctionStateMachines(const PartitionOptions& opts, FuncInfo funcInfo; int nodes_num = g->num_node_ids(); - const FunctionDefLibrary& fdef = opts.flib_def->ToProto(); + + const FunctionLibraryDefinition* flib_def = opts.flib_def; + if(flib_def == nullptr){ + flib_def = &(g->flib_def()); + } + + const FunctionDefLibrary& fdef = flib_def->ToProto(); + + + for (const FunctionDef& func : fdef.function()) { int num_inputs = func.signature().input_arg_size(); From 83e1ebc2f94adb529f2ef5cd1048e466bc71897b Mon Sep 17 00:00:00 2001 From: George Vasilakopoulos Date: Tue, 14 May 2024 10:59:20 +0000 Subject: [PATCH 28/53] What was intended here exactly? --- tensorflow/core/common_runtime/graph_execution_state.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/core/common_runtime/graph_execution_state.h b/tensorflow/core/common_runtime/graph_execution_state.h index 87b3a12891d45a..f118822a9525e5 100644 --- a/tensorflow/core/common_runtime/graph_execution_state.h +++ b/tensorflow/core/common_runtime/graph_execution_state.h @@ -58,7 +58,7 @@ struct ClientGraph { DataTypeVector feed_types, DataTypeVector fetch_types, int64_t collective_graph_key) : flib_def(std::move(flib)), - graph(flib_def.get()), + graph(*flib_def), feed_types(std::move(feed_types)), fetch_types(std::move(fetch_types)), collective_graph_key(collective_graph_key) {} From 6a951c9e33044e061e6a7ded2915b47c9cae4e92 Mon Sep 17 00:00:00 2001 From: George Vasilakopoulos Date: Tue, 14 May 2024 12:17:39 +0000 Subject: [PATCH 29/53] Fix duplicate flib imports --- tensorflow/core/common_runtime/direct_session.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/core/common_runtime/direct_session.cc b/tensorflow/core/common_runtime/direct_session.cc index 50285f87b2283c..c90606be3c421e 100644 --- a/tensorflow/core/common_runtime/direct_session.cc +++ b/tensorflow/core/common_runtime/direct_session.cc @@ -1722,7 +1722,7 @@ Status DirectSession::CreateGraphs( for (auto& partition : partitions) { std::unique_ptr device_graph( - new Graph(client_graph->flib_def.get())); + new Graph(OpRegistry::Global())); device_graph->SetConstructionContext(ConstructionContext::kDirectSession); GraphConstructorOptions device_opts; // There are internal operations (e.g., send/recv) that we now allow. 
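A note on the three library-related changes above (the flib_def fallback in graph_partition.cc, ClientGraph's graph(*flib_def), and the per-partition graphs built from OpRegistry::Global()): they all hinge on which Graph constructor a call site selects, which is presumably what the "What was intended here exactly?" commit is asking. Below is a minimal sketch of the difference as I read the overloads declared in tensorflow/core/graph/graph.h; the helper function itself is hypothetical and only exists to contrast the two forms.

#include <memory>

#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/graph/graph.h"

namespace tensorflow {

// Hypothetical helper, only to contrast the Graph constructor overloads.
void FlibSketch(const std::unique_ptr<FunctionLibraryDefinition>& flib_def) {
  // Graph(const OpRegistryInterface*): FunctionLibraryDefinition derives from
  // OpRegistryInterface, so flib_def.get() picks this overload. The graph then
  // starts with an empty function library and only delegates op lookups, i.e.
  // graph.flib_def().Find("Fac") would return nullptr.
  Graph op_lookup_only(flib_def.get());

  // Graph(const FunctionLibraryDefinition&): the library contents are copied
  // into the graph, so the function definitions travel with it. This is what
  // ClientGraph now does with graph(*flib_def).
  Graph with_functions(*flib_def);

  // The per-partition graphs in direct_session.cc instead start from the global
  // op registry and let graph construction import whatever functions it needs,
  // which matches the "Fix duplicate flib imports" intent.
  Graph per_partition(OpRegistry::Global());
}

}  // namespace tensorflow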
From bb167f8b738ef0b360435450227c7ab92bb3fed9 Mon Sep 17 00:00:00 2001 From: George Vasilakopoulos Date: Sat, 18 May 2024 21:56:58 +0000 Subject: [PATCH 30/53] Topological Ordering likely fixed --- tensorflow/core/grappler/utils/topological_sort.cc | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/tensorflow/core/grappler/utils/topological_sort.cc b/tensorflow/core/grappler/utils/topological_sort.cc index e6b70c3c2c4fe5..1d2d7b7bc681ff 100644 --- a/tensorflow/core/grappler/utils/topological_sort.cc +++ b/tensorflow/core/grappler/utils/topological_sort.cc @@ -71,7 +71,7 @@ Status ComputeTopologicalOrder( } bool recursion_merge = false; if (IsMerge(graph.node(i))) { - for (int input : graph_view.GetFanin(i)) { + for (int input : graph_view.GetFanin(i)) { if (IsNextIteration(graph.node(input))) { num_ready_inputs[i]++; } @@ -88,6 +88,8 @@ Status ComputeTopologicalOrder( // Nodes that send their output to "Return" nodes are // function Returning Nodes and in case of recursive functions // those nodes are part of graph cycles. + int id = 0; + num_ready_inputs[i] = 0; for (int input : graph_view.GetFanin(i)) { // In order to detect the recursion cycles we depend on // the fact that a recursive function's returning node, @@ -95,21 +97,22 @@ Status ComputeTopologicalOrder( // with different "call_id" attributes (same "call_id" // attrs would mean that they belong in the same function call // but they correspond to different function outputs) - // if (!StringPiece(graph.node(input)).starts_with("^")) { - if (true) { + if (!absl::StartsWith(graph.node(i).input(id), "^")) { + // if (true) { int call_id; TF_CHECK_OK(GetNodeAttr(graph.node(i), "call_id", &call_id)); returning_nodes[input].emplace(call_id); + num_ready_inputs[i]++; } + id++; } - num_ready_inputs[i] = 0; } } for (const auto& retnode : returning_nodes) { if (retnode.second.size() > 1) { // Detected Cycle - num_ready_inputs[retnode.first]++; + // num_ready_inputs[retnode.first]++; } } From 8fa6cb72dafa924ee6c16306f3eeba2afb0ca314 Mon Sep 17 00:00:00 2001 From: George Vasilakopoulos Date: Sat, 18 May 2024 21:58:12 +0000 Subject: [PATCH 31/53] Disable control flow v2 in test file --- recursion-tests/test2.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/recursion-tests/test2.py b/recursion-tests/test2.py index d27c5f4af1f587..673c56675ac30f 100644 --- a/recursion-tests/test2.py +++ b/recursion-tests/test2.py @@ -10,15 +10,12 @@ tf.compat.v1.disable_eager_execution() +tf.compat.v1.disable_control_flow_v2() # tf.logging.set_verbosity(tf.logging.INFO) fac = function.Declare("Fac", [("n", tf.int32)], [("ret", tf.int32)]) -@function.Defun(tf.int32, func_name="Test", out_names=["ret"]) -def t(n): - return tf.constant(1) - # fac = function.Declare("Fac", [("n", tf.int32)], [("ret", tf.int32)]) @@ -27,7 +24,7 @@ def t(n): def FacImpl(n): return tf.cond(tf.less_equal(n, 1), lambda: tf.constant(1), - lambda: n * fac(n - 1)) + lambda: n * fac(n - 1)) # @function.Defun(tf.int32, func_name="Fac", out_names=["ret"]) @@ -43,10 +40,10 @@ def FacImpl(n): print(tf.compat.v1.get_default_graph().as_graph_def()) -# writer = tf.compat.v1.summary.FileWriter('/tensorflow/TESTS/graph', tf.compat.v1.get_default_graph()) +writer = tf.compat.v1.summary.FileWriter('/tensorflow/recursion-tests/graph', tf.compat.v1.get_default_graph()) # writer = tf.summary.FileWriter('./graphs', tf.compat.v1.get_default_graph()) with tf.compat.v1.Session() as sess: - result = FacImpl(1) + result = FacImpl(10) print("Result:", 
sess.run(result)) From 894ea9da03a79f8ea9249496f3964245444a69f6 Mon Sep 17 00:00:00 2001 From: George Vasilakopoulos Date: Sat, 18 May 2024 21:59:04 +0000 Subject: [PATCH 32/53] Update Compilation File --- recursion-tests/comp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/recursion-tests/comp b/recursion-tests/comp index 90d6846bfdc8c0..90fa7e03907f60 100755 --- a/recursion-tests/comp +++ b/recursion-tests/comp @@ -1,5 +1,6 @@ #!/bin/bash cd .. -bazel build --config=dbg //tensorflow/tools/pip_package:wheel --repo_env=WHEEL_NAME=tensorflow_cpu && +bazel build --disk_cache=/root/mycache --per_file_copt=+tensorflow.*,-tensorflow/compiler.*,-tensorflow/lite.*,-tensorflow/core/kernels.*@-O0,-g //tensorflow/tools/pip_package:wheel --repo_env=WHEEL_NAME=tensorflow_cpu && +pip uninstall tensorflow_cpu -y && pip install bazel-bin/tensorflow/tools/pip_package/wheel_house/tensorflow_cpu-2.17.0-cp311-cp311-linux_x86_64.whl From 41058c8fc95e9db263daee0b05644ed150fb1520 Mon Sep 17 00:00:00 2001 From: George Vasilakopoulos Date: Mon, 20 May 2024 12:24:41 +0000 Subject: [PATCH 33/53] Graph construction seems to be working --- .../core/common_runtime/graph_constructor.cc | 41 +++++++++++++------ .../optimizers/function_transformation.cc | 10 +++-- 2 files changed, 36 insertions(+), 15 deletions(-) diff --git a/tensorflow/core/common_runtime/graph_constructor.cc b/tensorflow/core/common_runtime/graph_constructor.cc index bc3c98787bd685..0bba0bf4881dce 100644 --- a/tensorflow/core/common_runtime/graph_constructor.cc +++ b/tensorflow/core/common_runtime/graph_constructor.cc @@ -301,7 +301,7 @@ class GraphConstructor { // Decrement pending count for users of `processed` and add the ones that now // have all of their pending inputs satisfied to `ready_`. - void UpdatePendingCountAndReady(int processed, bool is_next_iteration); + void UpdatePendingCountAndReady(int processed, bool is_next_iteration, bool is_function_call); // Subclasses override the following virtual methods to provide efficient // access to the original protocol buffer-based graph. @@ -576,20 +576,21 @@ Status MaybeAppendVersionWarning(const VersionDef* versions, } void GraphConstructor::UpdatePendingCountAndReady(int processed, - bool is_next_iteration) { + bool is_next_iteration, bool is_function_call) { for (size_t i = 0; i < outputs_[processed].size(); ++i) { const int output = outputs_[processed][i]; // We didn't consider NextIteration->Merge edges when computing // pending_counts_ so we should not have to consider it here either. 
bool is_next_iteration_to_merge_edge = is_next_iteration && merge_node_indices_.count(output) == 1; - if (!is_next_iteration_to_merge_edge) { - int* current_pending_count = &pending_count_[output]; - CHECK_GT(*current_pending_count, 0); - (*current_pending_count)--; - if (*current_pending_count == 0) { - ready_.insert(output); - } + if (is_next_iteration_to_merge_edge)continue; + int* current_pending_count = &pending_count_[output]; + if (*current_pending_count == 0 && is_function_call) continue; + if (*current_pending_count == 0 && merge_node_indices_.count(output) == 1) continue; + CHECK_GT(*current_pending_count, 0); + (*current_pending_count)--; + if (*current_pending_count == 0) { + ready_.insert(output); } } } @@ -784,11 +785,25 @@ Status GraphConstructor::InitFromEdges() { } gtl::FlatSet call_nodes; + gtl::FlatSet merge_return_nodes; for (int n = 0; n < node_def_count(); ++n) { const NodeDef& node_def = get_node_def(n); if (IsCall(node_def)) { call_nodes.insert(node_def.name()); } + if (!IsMerge(node_def) && IsReturningNode(node_def)){ + for (const auto& input_name : node_def.input()) { + if (!absl::StartsWith(input_name, "^")) { + string prevNode = input_name; + size_t pos = input_name.find(":"); + + if (pos != std::string::npos) + prevNode = input_name.substr(0, pos); + + merge_return_nodes.insert(prevNode); + } + } + } } @@ -816,7 +831,9 @@ Status GraphConstructor::InitFromEdges() { if (next_iteration_nodes.find(string(id.first)) != next_iteration_nodes.end()|| call_nodes.find(string(id.first)) != - call_nodes.end()) { + call_nodes.end()|| + merge_return_nodes.find(node_def.name()) != + merge_return_nodes.end()) { has_loop_back_edge = true; } } @@ -1296,7 +1313,7 @@ Status GraphConstructor::Convert() { TF_RETURN_IF_ERROR(IsNodeFullyMapped(node_def, &is_node_mapped)); if (is_node_mapped) { // Skip this node after updating pending_count_ for outputs - UpdatePendingCountAndReady(o, IsNextIteration(node_def)); + UpdatePendingCountAndReady(o, IsNextIteration(node_def), IsCall(node_def)); continue; } } @@ -1422,7 +1439,7 @@ Status GraphConstructor::Convert() { TF_RETURN_IF_ERROR(ValidateShape(node)); // Update pending_count_ for outputs. 
- UpdatePendingCountAndReady(o, node->IsNextIteration()); + UpdatePendingCountAndReady(o, node->IsNextIteration(), node->IsCall()); } if (processed < node_def_count()) { diff --git a/tensorflow/core/grappler/optimizers/function_transformation.cc b/tensorflow/core/grappler/optimizers/function_transformation.cc index 465d0a844a9730..0852d6370debe8 100644 --- a/tensorflow/core/grappler/optimizers/function_transformation.cc +++ b/tensorflow/core/grappler/optimizers/function_transformation.cc @@ -306,7 +306,7 @@ Status CallRewriter::CollectCalls(std::vector& calls) { // outputFile << "In collect calls: "<< std::endl; for (NodeDef& node : *graph->mutable_node()) { const FunctionDef* func = ctx.FindInlinedFunction(node.op()); - // outputFile << "Collecting Calls: "<< node.name() << " " << node.op() << std::endl; + // printf("Collecting Calls: %s %s \n", node.name().c_str(), node.op().c_str()); if (func != nullptr) { CallInfo call; call.call_id = GetCallId(node); @@ -401,7 +401,7 @@ Status CallRewriter::TransformCall(CallInfo& call_info) { std::vector call_nodes; std::vector ret_nodes; - + call_nodes.resize(func_info.inputs.size()); for (unsigned int arg_num = 0; arg_num < func_info.inputs.size(); arg_num++) { call_nodes[arg_num] = graph->add_node(); @@ -519,7 +519,7 @@ Status InlineFunction(const FunctionDef& func_def, if (input_it != input_nodes.end()) { CHECK_EQ(0, func_body_node.input_size()); // Turn input placeholders into identity nodes - if (IsPlaceholder(func_body_node)) { + if (IsArg(func_body_node)) { func_body_node.set_op(kIdentityOp); } // Connect merge with input arg @@ -530,6 +530,10 @@ Status InlineFunction(const FunctionDef& func_def, for (string& input : *func_body_node.mutable_input()) { input = AddPrefixToNodeName(input, prefix); } + // If this is a return node, change the op to KIdentityOp + if(IsRetval(func_body_node)){ + func_body_node.set_op(kIdentityOp); + } // If the node has no input, make hook it up to the Merge nodes to ensure // it runs in the same frame as the other nodes of the function body. 
if (func_body_node.input_size() == 0) { From b37b5e62354bfeebbf713804e8ce4468a3efc441 Mon Sep 17 00:00:00 2001 From: George Vasilakopoulos Date: Wed, 22 May 2024 12:43:30 +0000 Subject: [PATCH 34/53] Executor changes --- .../common_runtime/immutable_executor_state.cc | 14 ++++++++++++++ .../core/common_runtime/immutable_executor_state.h | 8 ++++++++ tensorflow/core/common_runtime/propagator_state.cc | 12 +++++++----- 3 files changed, 29 insertions(+), 5 deletions(-) diff --git a/tensorflow/core/common_runtime/immutable_executor_state.cc b/tensorflow/core/common_runtime/immutable_executor_state.cc index 1ef9ad362dd9c1..2bf3c0ba1ebe41 100644 --- a/tensorflow/core/common_runtime/immutable_executor_state.cc +++ b/tensorflow/core/common_runtime/immutable_executor_state.cc @@ -190,7 +190,21 @@ Status ImmutableExecutorState::Initialize(const Graph& graph) { } else { item->is_constant_enter = false; } + item->is_call = IsCall(n); + + if(item->is_call){ + string frame_name; + TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "frame_name", &frame_name)); + FrameInfo* frame_info = frame_info_[frame_name].get(); + frame_info->parallel_iterations = 1; + if (call_frame_info_.size() <= id) { + call_frame_info_.resize(id + 1); + } + call_frame_info_[id] = frame_info; + } + + item->is_return = IsReturn(n); item->is_call_or_return = (IsCall(n) || IsReturn(n)); item->is_exit = IsExit(n); diff --git a/tensorflow/core/common_runtime/immutable_executor_state.h b/tensorflow/core/common_runtime/immutable_executor_state.h index a1fca080ca6c5c..320da16860c5d5 100644 --- a/tensorflow/core/common_runtime/immutable_executor_state.h +++ b/tensorflow/core/common_runtime/immutable_executor_state.h @@ -99,6 +99,13 @@ class ImmutableExecutorState { return *enter_frame_info_[node_item.node_id]; } + const FrameInfo& get_call_frame_info(const NodeItem& node_item) const { + DCHECK(node_item.is_call); + return *call_frame_info_[node_item.node_id]; + } + + + bool requires_control_flow_support() const { return requires_control_flow_; } // Copies the pending counts for nodes in this graph to the given array. @@ -147,6 +154,7 @@ class ImmutableExecutorState { // If the graph contains any "Enter" or "RefEnter" nodes, this vector maps // dense node IDs to the corresponding FrameInfo. std::vector enter_frame_info_; + std::vector call_frame_info_; // If `requires_control_flow_` is false, this points to an array of initial // pending counts for the nodes in the graph, indexed by node ID. diff --git a/tensorflow/core/common_runtime/propagator_state.cc b/tensorflow/core/common_runtime/propagator_state.cc index 4ec560ce2d7aa6..8a95fd386a8768 100644 --- a/tensorflow/core/common_runtime/propagator_state.cc +++ b/tensorflow/core/common_runtime/propagator_state.cc @@ -7,7 +7,7 @@ You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, +distributed under the License is distributed on a:n "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. @@ -97,8 +97,10 @@ void PropagatorState::PropagateOutputs(const TaggedNode& tagged_node, // printf("Propagate Outputs: %s, am i alive? %d\n", node->name().c_str(), !is_dead); // printf("Frame: %s\n", input_frame->frame_name.c_str()); - // printf("Propagate Outputs: %s, am i alive? 
%d\n", node->name().c_str(), !is_dead); - // printf("Frame: %s\n", input_frame->frame_name.c_str()); + string output(tagged_node.node_item->kernel->name_view()); + + printf("Propagate Outputs: %s, am i alive? %d\n",output.c_str(), !is_dead); + printf("Frame: %s\n", input_frame->frame_name.c_str()); if (!item->is_enter_exit_or_next_iter && !item->is_call_or_return) { @@ -149,7 +151,7 @@ void PropagatorState::PropagateOutputs(const TaggedNode& tagged_node, // output_frame = nullptr; // } else { FindOrCreateChildFrame(input_frame, input_iter, *item, &output_frame); - output_iter = 0; + output_iter = output_frame->GetIteration(0); { mutex_lock l(output_frame->mu); int activated = output_frame->ActivateNodesLocked( @@ -298,7 +300,7 @@ void PropagatorState::FindOrCreateChildFrame(FrameState* frame, FrameState** child) { // Get the child frame name. const ImmutableExecutorState::FrameInfo& frame_info = - immutable_state_.get_enter_frame_info(node_item); + node_item.is_enter ? immutable_state_.get_enter_frame_info(node_item) : immutable_state_.get_call_frame_info(node_item); const uint64 child_id = Hash64Combine( frame->frame_id, From 8c3d37fcc454a864d692c5e232a10051b6fc7291 Mon Sep 17 00:00:00 2001 From: George Vasilakopoulos Date: Thu, 23 May 2024 14:11:05 +0000 Subject: [PATCH 35/53] Include call_id when producing frame identifier --- tensorflow/core/common_runtime/propagator_state.cc | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/tensorflow/core/common_runtime/propagator_state.cc b/tensorflow/core/common_runtime/propagator_state.cc index 8a95fd386a8768..67a195ee72cc58 100644 --- a/tensorflow/core/common_runtime/propagator_state.cc +++ b/tensorflow/core/common_runtime/propagator_state.cc @@ -99,8 +99,8 @@ void PropagatorState::PropagateOutputs(const TaggedNode& tagged_node, string output(tagged_node.node_item->kernel->name_view()); - printf("Propagate Outputs: %s, am i alive? %d\n",output.c_str(), !is_dead); - printf("Frame: %s\n", input_frame->frame_name.c_str()); + // printf("Propagate Outputs: %s, am i alive? %d\n",output.c_str(), !is_dead); + // printf("Frame: %s\n", input_frame->frame_name.c_str()); if (!item->is_enter_exit_or_next_iter && !item->is_call_or_return) { @@ -151,6 +151,7 @@ void PropagatorState::PropagateOutputs(const TaggedNode& tagged_node, // output_frame = nullptr; // } else { FindOrCreateChildFrame(input_frame, input_iter, *item, &output_frame); + // printf("Inside Call: %s. Input frame id: %d, Output frame id %d\n", output.c_str(),input_frame->frame_id,output_frame->frame_id); output_iter = output_frame->GetIteration(0); { mutex_lock l(output_frame->mu); @@ -167,6 +168,7 @@ void PropagatorState::PropagateOutputs(const TaggedNode& tagged_node, // } else { output_frame = input_frame->parent_frame; output_iter = input_frame->parent_iter; + // printf("Inside Return: %s. 
Input frame id: %d, Output frame id %d\n", output.c_str(),input_frame->frame_id,output_frame->frame_id); { mutex_lock l(output_frame->mu); int activated = output_frame->ActivateNodesLocked( @@ -304,7 +306,7 @@ void PropagatorState::FindOrCreateChildFrame(FrameState* frame, const uint64 child_id = Hash64Combine( frame->frame_id, - Hash64Combine(iter_state->iter_num, Hash64(frame_info.name))); + Hash64Combine(iter_state->iter_num, Hash64(frame_info.name + ":" + std::to_string(node_item.call_id)))); { tf_shared_lock executor_lock(frame->mu); From 7ad869a7c551f326f564799133889f2cb97b9b25 Mon Sep 17 00:00:00 2001 From: George Vasilakopoulos Date: Thu, 23 May 2024 14:12:22 +0000 Subject: [PATCH 36/53] Breakpoint file --- file.gdb | 5 +++++ 1 file changed, 5 insertions(+) create mode 100644 file.gdb diff --git a/file.gdb b/file.gdb new file mode 100644 index 00000000000000..5bba091bbdcaa1 --- /dev/null +++ b/file.gdb @@ -0,0 +1,5 @@ +break core/common_runtime/immutable_executor_state.cc:89 +continue +break core/common_runtime/direct_session.cc:918 +break core/common_runtime/direct_session.cc:745 +break core/common_runtime/propagator_state.cc:106 \ No newline at end of file From 68404500f4b11fa028a63174aa1fbf1c28b17d37 Mon Sep 17 00:00:00 2001 From: George Vasilakopoulos Date: Sat, 25 May 2024 15:41:41 +0000 Subject: [PATCH 37/53] Adding some test files --- recursion-tests/distributed/d2.py | 6 ++++- recursion-tests/distributed/d3.py | 2 +- recursion-tests/distributed/distr.py | 38 +++++++++++----------------- recursion-tests/exponents.py | 37 +++++++++++++++++++++++++++ recursion-tests/takeuchi.py | 22 ++++++++++++++++ 5 files changed, 80 insertions(+), 25 deletions(-) create mode 100644 recursion-tests/exponents.py create mode 100644 recursion-tests/takeuchi.py diff --git a/recursion-tests/distributed/d2.py b/recursion-tests/distributed/d2.py index daa16012ea4bca..0a1114671dc8f5 100644 --- a/recursion-tests/distributed/d2.py +++ b/recursion-tests/distributed/d2.py @@ -1,8 +1,12 @@ +import os + +os.environ['TF_CPP_MAX_VLOG_LEVEL'] = '2' + import tensorflow as tf cluster_spec = { - "local": ["172.19.0.3:2222", "172.19.0.2:2223"] + "local": ["localhost:2222", "localhost:2223"] } server = tf.distribute.Server(cluster_spec, job_name="local", task_index=0) diff --git a/recursion-tests/distributed/d3.py b/recursion-tests/distributed/d3.py index 6f1bc6ccec6a68..d2a190329de22a 100644 --- a/recursion-tests/distributed/d3.py +++ b/recursion-tests/distributed/d3.py @@ -7,7 +7,7 @@ cluster_spec = { - "local": ["172.19.0.3:2222", "172.19.0.2:2223"] + "local": ["localhost:2222", "localhost:2223"] } server = tf.distribute.Server(cluster_spec, job_name="local", task_index=1) diff --git a/recursion-tests/distributed/distr.py b/recursion-tests/distributed/distr.py index 13805180109c25..615aaa61cad69e 100644 --- a/recursion-tests/distributed/distr.py +++ b/recursion-tests/distributed/distr.py @@ -7,46 +7,38 @@ from tensorflow.python.framework import function tf.compat.v1.disable_eager_execution() - -cluster = tf.train.ClusterSpec({"local": ["localhost:2222", "localhost:2223"]}) - -fib = function.Declare("Fib", [("n", tf.int32)], [("ret", tf.int32)]) - +tf.compat.v1.disable_control_flow_v2() +cluster = tf.train.ClusterSpec({"local": ["localhost:2222", "localhost:2223"]}) -@function.Defun(tf.int32, func_name="Minus2", out_names=["ret"]) -def M2(n): - return (n-2) +fac = function.Declare("Fac", [("n", tf.int32)], [("ret", tf.int32)]) -@function.Defun(tf.int32, func_name="Minus1", out_names=["ret"]) -def M1(n): - 
return (n-1) -@function.Defun(tf.int32, func_name="Fib", out_names=["ret"]) -def FibImpl(n): +@function.Defun(tf.int32, func_name="Fac", out_names=["ret"]) +def FacImpl(n): def f1(): - with tf.device("/job:local/replica:0/task:0/device:CPU:0"): + with tf.device("/job:local/replica:0/task:1/device:CPU:0"): ret = tf.constant(1) return ret def f2(): with tf.device("/job:local/replica:0/task:0/device:CPU:0"): - fib1 = M1(n) - with tf.device("/job:local/replica:0/task:0/device:CPU:0"): - fib2 = M2(n) - - return fib1 + fib2 + ret = n * fac(n - 1) + return ret + + with tf.device("/job:local/replica:0/task:0/device:CPU:0"): + pred = tf.less_equal(n, 1) - return tf.cond(tf.less_equal(n, 1), f1, f2) + return tf.cond(pred, f1, f2) -FibImpl.add_to_graph(tf.compat.v1.get_default_graph()) +FacImpl.add_to_graph(tf.compat.v1.get_default_graph()) -n = tf.constant(11) -x = fib(n) +n = tf.constant(10) +x = fac(n) #print(tf.get_default_graph().as_graph_def()) diff --git a/recursion-tests/exponents.py b/recursion-tests/exponents.py new file mode 100644 index 00000000000000..42dc40a90ef8e6 --- /dev/null +++ b/recursion-tests/exponents.py @@ -0,0 +1,37 @@ +import os +import tensorflow as tf +from tensorflow.python.framework import function + + + +tf.compat.v1.disable_eager_execution() +tf.compat.v1.disable_control_flow_v2() + +exp = function.Declare("EXPONENT", [("x", tf.int32), ("n", tf.int32)], [("ret", tf.int32)]) + + + + +# fac = function.Declare("Fac", [("n", tf.int32)], [("ret", tf.int32)]) + +@function.Defun(tf.int32, tf.int32, func_name="EXPONENT", out_names=["ret"]) +def ExpImpl(x, n): + return tf.cond(tf.equal(n,0), + lambda: tf.constant(1), + lambda: x*exp(x,n-1)) + +# @function.Defun(tf.int32, func_name="Fac", out_names=["ret"]) +# def FacImpl2(n): +# return t(1) + + +ExpImpl.add_to_graph(tf.compat.v1.get_default_graph()) +# t.add_to_graph(tf.compat.v1.get_default_graph()) +# FacImpl2.add_to_graph(tf.compat.v1.get_default_graph()) + +# writer = tf.summary.FileWriter('./graphs', tf.compat.v1.get_default_graph()) + +with tf.compat.v1.Session() as sess: + result = ExpImpl(2,5) + print("Result:", sess.run(result)) + diff --git a/recursion-tests/takeuchi.py b/recursion-tests/takeuchi.py new file mode 100644 index 00000000000000..0da79c39f64648 --- /dev/null +++ b/recursion-tests/takeuchi.py @@ -0,0 +1,22 @@ +import tensorflow as tf +from tensorflow.python.framework import function + +tf.compat.v1.disable_eager_execution() +tf.compat.v1.disable_control_flow_v2() + +tak = function.Declare("Tak", [("x", tf.int32), ("y", tf.int32), ("z", tf.int32)], [("ret", tf.int32)]) + +@function.Defun(tf.int32, tf.int32, tf.int32, func_name="Tak", out_names=["ret"]) +def TakImpl(x,y,z): + return tf.cond(tf.less(y, x), + lambda: tak(tak(x-1,y,z), tak(y-1,z,x), tak(z-1,x,y)), + lambda: z) + +TakImpl.add_to_graph(tf.compat.v1.get_default_graph()) + + +with tf.compat.v1.Session() as sess: + result = TakImpl(24,16,8) + print("Result:", sess.run(result)) + +#print(tf.get_default_graph().as_graph_def()) \ No newline at end of file From 12226d264d2f84508edeb31f880bb379c10962b9 Mon Sep 17 00:00:00 2001 From: George Vasilakopoulos Date: Sat, 25 May 2024 15:46:38 +0000 Subject: [PATCH 38/53] Avoid creating self - control edges --- tensorflow/core/graph/graph_partition.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tensorflow/core/graph/graph_partition.cc b/tensorflow/core/graph/graph_partition.cc index 9893f4892917e8..1e7a5d0690de8e 100644 --- a/tensorflow/core/graph/graph_partition.cc +++ 
b/tensorflow/core/graph/graph_partition.cc @@ -1424,7 +1424,7 @@ Status CallingFunction(Graph* graph, GraphDef& main_graphDef, StateMachine& stat } } } else - ConnectMergeToNode(main_graphDef, merge->name(), out->name(), state_machine, dst_device); + if(merge->name() != out->name()) ConnectMergeToNode(main_graphDef, merge->name(), out->name(), state_machine, dst_device); } if (ready_inputs[out] == out->in_edges().size()) { @@ -1527,7 +1527,7 @@ Status AddFunctionStateMachines(const PartitionOptions& opts, // We convert graph to its equivalent graph_def, cause it's easier // to extend it with the GraphDef state machines of partitions g->ToGraphDef(&main_graphDef); - + while (!ready_nodes.empty()) { Node* ready_node = ready_nodes.front(); ready_nodes.pop_front(); From 80dc042d5257219dc1fa5b74789f401ad2501996 Mon Sep 17 00:00:00 2001 From: George Vasilakopoulos Date: Wed, 5 Jun 2024 17:17:07 +0000 Subject: [PATCH 39/53] Test code for how to find grad graph --- .../optimizers/function_transformation.cc | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/tensorflow/core/grappler/optimizers/function_transformation.cc b/tensorflow/core/grappler/optimizers/function_transformation.cc index 0852d6370debe8..2984bf498212e6 100644 --- a/tensorflow/core/grappler/optimizers/function_transformation.cc +++ b/tensorflow/core/grappler/optimizers/function_transformation.cc @@ -22,6 +22,8 @@ limitations under the License. #include "tensorflow/core/util/events_writer.h" #include "tensorflow/core/common_runtime/graph_constructor.h" +#include "tensorflow/core/common_runtime/function.h" +#include "tensorflow/core/common_runtime/function_def_utils.h" #include "tensorflow/core/graph/tensor_id.h" #include "tensorflow/core/grappler/grappler_item.h" #include "tensorflow/core/grappler/op_types.h" @@ -89,6 +91,17 @@ class FunctionInliningContext { // if (func.attr().count("_noinline") != 0) { // continue; // } + + std::unique_ptr fbody; + Status stat = FunctionDefToBodyHelper( + func, AttrSlice(&func.attr()), &function_library_, &fbody); + + fbody = SymbolicGradient(*fbody); + + + + printf("Gradient graph %s\n\n",SummarizeGraphDef((fbody)->graph->ToGraphDefDebug()).c_str()); + // Don't touch anything marked XLA to prevent XLA failures further down // the road. 
if (func.attr().count("_XlaCompile") > 0 && @@ -599,10 +612,6 @@ Status FunctionTransformation::Optimize(Cluster* cluster, const GrapplerItem& it GraphDef* output) { - - // outputFile << "In Optimize" << std::endl; - - FunctionInliningContext ctx(item); CallRewriter call_rewriter(item, output, ctx); From 65deff8373aceb71f7f6f906a788111f4b0c0782 Mon Sep 17 00:00:00 2001 From: George Vasilakopoulos Date: Wed, 5 Jun 2024 17:17:23 +0000 Subject: [PATCH 40/53] Updated test file --- recursion-tests/exponents.py | 31 +++++++++++++++++++++++-------- 1 file changed, 23 insertions(+), 8 deletions(-) diff --git a/recursion-tests/exponents.py b/recursion-tests/exponents.py index 42dc40a90ef8e6..8ef9f320b3c8ea 100644 --- a/recursion-tests/exponents.py +++ b/recursion-tests/exponents.py @@ -2,23 +2,24 @@ import tensorflow as tf from tensorflow.python.framework import function - +os.environ['TF_CPP_MAX_VLOG_LEVEL'] = '2' tf.compat.v1.disable_eager_execution() tf.compat.v1.disable_control_flow_v2() -exp = function.Declare("EXPONENT", [("x", tf.int32), ("n", tf.int32)], [("ret", tf.int32)]) +exp = function.Declare("EXPONENT", [("x", tf.float32), ("n", tf.int32)], [("ret", tf.float32)]) # fac = function.Declare("Fac", [("n", tf.int32)], [("ret", tf.int32)]) -@function.Defun(tf.int32, tf.int32, func_name="EXPONENT", out_names=["ret"]) +@function.Defun(tf.float32, tf.int32, func_name="EXPONENT", out_names=["ret"]) def ExpImpl(x, n): return tf.cond(tf.equal(n,0), - lambda: tf.constant(1), - lambda: x*exp(x,n-1)) + lambda: tf.cast(tf.constant(1),tf.float32), + lambda: x*x) + # @function.Defun(tf.int32, func_name="Fac", out_names=["ret"]) # def FacImpl2(n): @@ -29,9 +30,23 @@ def ExpImpl(x, n): # t.add_to_graph(tf.compat.v1.get_default_graph()) # FacImpl2.add_to_graph(tf.compat.v1.get_default_graph()) + +x = tf.compat.v1.get_variable('n_var', [], initializer=tf.constant_initializer(4.0)) +y = ExpImpl(x,2) + +train_op = tf.compat.v1.train.GradientDescentOptimizer(0.01).minimize(y) +print(tf.compat.v1.get_default_graph().as_graph_def()) + + +sess = tf.compat.v1.Session() +sess.run(tf.compat.v1.initialize_all_variables()) +print(x.eval(session=sess)) +print(sess.run(train_op)) +print(x.eval(session=sess)) + # writer = tf.summary.FileWriter('./graphs', tf.compat.v1.get_default_graph()) -with tf.compat.v1.Session() as sess: - result = ExpImpl(2,5) - print("Result:", sess.run(result)) +# with tf.compat.v1.Session() as sess: +# result = ExpImpl(2,5) +# print("Result:", sess.run(result)) From 8ff57b731558608a4e63a17c6c3d64d62d8c56f8 Mon Sep 17 00:00:00 2001 From: George Vasilakopoulos Date: Thu, 6 Jun 2024 14:04:39 +0000 Subject: [PATCH 41/53] Autodiff code migration. Compileable --- .../optimizers/function_transformation.cc | 934 +++++++++++------- 1 file changed, 578 insertions(+), 356 deletions(-) diff --git a/tensorflow/core/grappler/optimizers/function_transformation.cc b/tensorflow/core/grappler/optimizers/function_transformation.cc index 2984bf498212e6..faad2f658775f1 100644 --- a/tensorflow/core/grappler/optimizers/function_transformation.cc +++ b/tensorflow/core/grappler/optimizers/function_transformation.cc @@ -11,9 +11,6 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include -// std::ofstream outputFile("/tensorflow/TESTS/mylog.txt"); - #include "tensorflow/core/grappler/optimizers/function_transformation.h" #include #include @@ -21,14 +18,7 @@ limitations under the License. #include "tensorflow/core/util/event.pb.h" #include "tensorflow/core/util/events_writer.h" -#include "tensorflow/core/common_runtime/graph_constructor.h" -#include "tensorflow/core/common_runtime/function.h" -#include "tensorflow/core/common_runtime/function_def_utils.h" #include "tensorflow/core/graph/tensor_id.h" -#include "tensorflow/core/grappler/grappler_item.h" -#include "tensorflow/core/grappler/op_types.h" -#include "tensorflow/core/grappler/utils.h" -#include "tensorflow/core/grappler/utils/functions.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/function.pb.h" #include "tensorflow/core/framework/graph_def_util.h" @@ -36,9 +26,14 @@ limitations under the License. #include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/op_def.pb.h" #include "tensorflow/core/framework/versions.pb.h" +#include "tensorflow/core/grappler/grappler_item.h" +#include "tensorflow/core/grappler/op_types.h" +#include "tensorflow/core/grappler/utils.h" +#include "tensorflow/core/grappler/utils/functions.h" +#include "tensorflow/core/lib/gtl/map_util.h" namespace tensorflow { -namespace grappler { +namespace grappler { namespace { static constexpr const char* const kCallOp = "Call"; @@ -46,88 +41,19 @@ static constexpr const char* const kRetOp = "Return"; static constexpr const char* const kIdentityOp = "Identity"; static constexpr const char* const kIdentityNOp = "IdentityN"; static constexpr const char* const kMergeOp = "Merge"; +static constexpr const char* const kGradientOp = + FunctionLibraryDefinition::kGradientOp; +static constexpr const char* const kFuncAttrName = + FunctionLibraryDefinition::kFuncAttr; +static constexpr const char* kNoInlineAttr = "_noinline"; + +bool AttrIsTrue(const FunctionDef& func, const string& attr) { + return func.attr().count(attr) != 0 && func.attr().at(attr).b(); +} -struct FuncInfo { - gtl::ArraySlice fetch; - std::vector inputs; - std::vector input_def; - std::vector outputs; - std::vector output_def; -}; - -// same with commit b691c0 (possibly) -class FunctionInliningContext { - public: - explicit FunctionInliningContext(const GrapplerItem& item) - : item_(&item), function_library_(OpRegistry::Global(), item.graph.library()), functions_(InliningCandidates(item)) {} - - const FunctionLibraryDefinition& Library() const { return function_library_; } - - bool HasInlinedFunctions() const { return !functions_.empty(); } - - // Find inlining candidate by name. Return nullptr if not found. 
- const FunctionDef* FindInlinedFunction(const string& name) const { - auto it = functions_.find(name); - if (it != functions_.end()) { - return it->second; - } else { - return nullptr; - } - } - - const int graph_version() const { - return item_->graph.versions().producer(); - } - - private: - std::unordered_map InliningCandidates(const GrapplerItem& item) const { - std::unordered_map functions; - - - // outputFile << "In inliningcandidates " << SummarizeGraphDef(item.graph)<< std::endl; - for (const FunctionDef& func : item.graph.library().function()) { - // outputFile << func.signature().name() << std::endl; - // Don't inline functions marked as noinline - // if (func.attr().count("_noinline") != 0) { - // continue; - // } - - std::unique_ptr fbody; - Status stat = FunctionDefToBodyHelper( - func, AttrSlice(&func.attr()), &function_library_, &fbody); - - fbody = SymbolicGradient(*fbody); - - - - printf("Gradient graph %s\n\n",SummarizeGraphDef((fbody)->graph->ToGraphDefDebug()).c_str()); - - // Don't touch anything marked XLA to prevent XLA failures further down - // the road. - if (func.attr().count("_XlaCompile") > 0 && - func.attr().at("_XlaCompile").b()) { - continue; - } - // Can't create IdentityN nodes with no input or output: skip these - // functions for now. - if (func.signature().input_arg_size() == 0 || - func.signature().output_arg_size() == 0) { - continue; - } - functions[func.signature().name()] = &func; - } - // outputFile << "Returning!"< functions_; - const GrapplerItem* item_; - TF_DISALLOW_COPY_AND_ASSIGN(FunctionInliningContext); -}; - - -constexpr const char* const kFuncAttr = FunctionLibraryDefinition::kFuncAttr; +bool MarkedNoInline(const FunctionDef& func) { + return AttrIsTrue(func, kNoInlineAttr); +} // There are two ways of calling a Tensorflow function: // @@ -145,31 +71,11 @@ bool IsDirectFunctionCall(const FunctionDef& func, const NodeDef& func_node) { // Check if func_node has function attribute with a function name matching // FunctionDef signature. bool IsIndirectFunctionCall(const FunctionDef& func, const NodeDef& func_node) { - if (!IsPartitionedCall(func_node) && !IsStatefulPartitionedCall(func_node)) { - return false; - } - - auto* func_attr = AttrSlice(func_node).Find(kFuncAttr); + auto* func_attr = AttrSlice(func_node).Find(kFuncAttrName); return func_attr != nullptr && func_attr->has_func() && func_attr->func().name() == func.signature().name(); } -AttrSlice FunctionInstantiationAttributes(const FunctionDef& func, - const NodeDef& func_node) { - if (IsDirectFunctionCall(func, func_node)) { - return AttrSlice(func_node); - - } else if (IsIndirectFunctionCall(func, func_node)) { - auto* func_attr = AttrSlice(func_node).Find(kFuncAttr); - return AttrSlice(&func_attr->func().attr()); - - } else { - LOG(WARNING) << "Can't resolve function instantiation attributes: " - << SummarizeNodeDef(func_node); - return AttrSlice(); - } -} - // Copy input/output argument type to the type_list. Return error if argument // type is not explicitly defined, and not specified in function attributes. 
Status CopyArgType(const OpDef::ArgDef& arg, @@ -205,20 +111,130 @@ Status CopyArgType(const OpDef::ArgDef& arg, return OkStatus(); } + +AttrSlice FunctionInstantiationAttributes(const FunctionDef& func, + const NodeDef& func_node) { + if (IsDirectFunctionCall(func, func_node)) { + return AttrSlice(func_node); + + } else if (IsIndirectFunctionCall(func, func_node)) { + auto* func_attr = AttrSlice(func_node).Find(kFuncAttrName); + return AttrSlice(&func_attr->func().attr()); + + } else { + LOG(WARNING) << "Can't resolve function instantiation attributes: " + << SummarizeNodeDef(func_node); + return AttrSlice(); + } +} + +struct FuncInfo { + DataTypeVector arg_types; + DataTypeVector ret_types; + std::vector args; + std::vector rets; +}; + +struct FuncGradInfo { + FuncInfo f; + FuncInfo g; +}; + +// same with commit a9a3b98 (possibly) +class FunctionInliningContext { + public: + explicit FunctionInliningContext(const GrapplerItem& item) + : function_library_(FunctionLibraryDefinition(OpRegistry::Global(), + item.graph.library())) { + InitializeInlinedFunctions(item); + InitializeFetchNodes(item); + } + + + const FunctionLibraryDefinition& FunctionLibrary() const { + return function_library_; + } + + bool HasInlinedFunctions() const { return !inlined_functions_.empty(); } + + bool IsInlinedFunction(const string& name) const { + return inlined_functions_.count(name) > 0; + } + + // Find inlining candidate by name. Return nullptr if not found. + const FunctionDef* FindInlinedFunction(const string& name) const { + return gtl::FindWithDefault(inlined_functions_, name, nullptr); + } + + bool IsFetchNode(const string& node_name) const { + return fetch_nodes_.find(node_name) != fetch_nodes_.end(); + } + + const FunctionDef* FindInlinedFunctionAndGradient(const string& name) const { + string grad_name = strings::StrCat(name, "Grad"); + return FindInlinedFunction(grad_name); + } + + private: + void InitializeInlinedFunctions(const GrapplerItem& item) { + for (const FunctionDef& func : item.graph.library().function()) { + + bool marked_noinline = MarkedNoInline(func); + // Don't inline functions marked as noinline + if (marked_noinline) { + continue; + } + // Don't touch anything marked XLA to prevent XLA failures further down + // the road. + if (func.attr().count("_XlaCompile") > 0 && + func.attr().at("_XlaCompile").b()) { + continue; + } + // Can't create IdentityN nodes with no input or output: skip these + // functions for now. 
+ if (func.signature().input_arg_size() == 0 || + func.signature().output_arg_size() == 0) { + continue; + } + inlined_functions_[func.signature().name()] = &func; + } + } + + void InitializeFetchNodes(const GrapplerItem& item) { + for (const string& fetch : item.fetch) { + fetch_tensors_.insert(fetch); + fetch_nodes_.insert(NodeName(fetch)); + } + } + + FunctionLibraryDefinition function_library_; + std::unordered_map inlined_functions_; + gtl::FlatSet fetch_tensors_; // format: node_name:port + gtl::FlatSet fetch_nodes_; // format: node_name + + TF_DISALLOW_COPY_AND_ASSIGN(FunctionInliningContext); +}; + struct CallInfo { - int call_id; - NodeDef* node; - string node_name; - string function_name; - string device; - std::vector input_nodes; - AttrSlice attr; + int call_id; + string call_frame; + NodeDef* fcall = nullptr; + NodeDef* gcall = nullptr; + bool hasGradient() const { return (gcall != nullptr); } +}; + +struct TransformationResult { + int call_id; + string call_frame; + NodeDef* transformed_node; + std::vector call_nodes; + std::vector ret_nodes; }; class CallRewriter { public: - explicit CallRewriter(const GrapplerItem item_, GraphDef* graph_, const FunctionInliningContext& ctx_) + explicit CallRewriter(const GrapplerItem& item_, GraphDef* graph_, const FunctionInliningContext& ctx_) : graph(graph_), ctx(ctx_), item(item_) { } ~CallRewriter() { @@ -227,139 +243,80 @@ class CallRewriter { Status CollectCalls(std::vector& calls); - Status TransformCall(CallInfo& call_info); + Status TransformCall(const CallInfo& call_info); // Inlines a function to item.graph and if already inlined provide func_info - Status FindCompatibleOrInlineFunction(const string& name, - const AttrSlice& func_attr, - const string& device, - GraphDef* optimized_graph, FuncInfo& func_info); - - void Flush() { - if (!nodes_to_delete.empty()) { - // garbage collect the transformed call nodes - int last = graph->node_size() - 1; - for (int i = graph->node_size() - 1; i >= 0; --i) { - const NodeDef& node = graph->node(i); - if (nodes_to_delete.find(node.name()) != nodes_to_delete.end()) { - graph->mutable_node()->SwapElements(i,last); - last--; - } - } + Status FindCompatibleOrInlineFunction(const CallInfo& call, + GraphDef* optimized_graph, + FuncGradInfo& func_info); - graph->mutable_node()->DeleteSubrange(last + 1, - graph->node_size() - last - 1); - - nodes_to_delete.clear(); - } - - if (!output_map_.empty()) { - // change all the recorded outputs; - // the new outputs where produced by the addition of the RetOp and - // the substitution was deferred to increase performance - for (NodeDef& node : *graph->mutable_node()) { - for (string& in : *node.mutable_input()) { - auto it = output_map_.find(in); - if (it != output_map_.end()) { - in = it->second; - } - } - } - output_map_.clear(); - } - } + void Flush(); inline int GetCallId(const NodeDef& node) { int call_id = id; id++; return call_id; } private: - Status AddCallOp(const CallInfo& call_info, const OpDef::ArgDef arg, - const string& input, int arg_id, NodeDef* call_node); - - Status AddRetOp(const CallInfo& call_info, const OpDef::ArgDef arg, - const string& input, int arg_id, NodeDef* ret_node); - - Status ConnectInput(NodeDef* from, NodeDef* to); - - bool ShouldPreserveOutputs(const string& node) { - for (const string& fetch_out : item.fetch) { - if (NodeName(fetch_out) == node) - return true; - } - return false; - } + Status TransformNode(const CallInfo& info, + NodeDef* call, const FuncInfo& f, + std::vector& call_nodes, + std::vector& ret_nodes); void 
ReplaceOutput(const string& old_output, const string& new_output) { // maybe some more checks output_map_[old_output] = new_output; } - void MarkCallTransformed(CallInfo& call_info) { - NodeDef* node = call_info.node; - node->clear_input(); - node->set_op("NoOp"); - node->set_name(AddPrefixToNodeName(node->name(), "$MarkToDelete$")); - nodes_to_delete.insert(node->name()); + void MarkCallTransformed(const CallInfo& call_info) { + CHECK_NOTNULL(call_info.fcall); + MarkNodeDelete(call_info.fcall); + + if (call_info.gcall != nullptr) { + MarkNodeDelete(call_info.gcall); + } + } + + void MarkTransformed(TransformationResult& result) { + NodeDef* n = result.transformed_node; + CHECK_NOTNULL(n); + transformed_calls_[result.transformed_node->name()] = result; + n->clear_input(); + n->set_op("NoOp"); + n->set_name(AddPrefixToNodeName(n->name(), "$MarkToDelete$")); + nodes_to_delete.insert(n->name()); + } + + void MarkNodeDelete(NodeDef* n) { + n->clear_input(); + n->set_op("NoOp"); + n->set_name(AddPrefixToNodeName(n->name(), "$MarkToDelete$")); + nodes_to_delete.insert(n->name()); } GraphDef* graph; const FunctionInliningContext& ctx; - const GrapplerItem item; - std::unordered_map transformed_functions_; + const GrapplerItem& item; + std::unordered_map transformed_functions_; std::unordered_map output_map_; + std::unordered_map transformed_calls_; std::set nodes_to_delete; int id = 0; TF_DISALLOW_COPY_AND_ASSIGN(CallRewriter); }; - -Status CallRewriter::CollectCalls(std::vector& calls) { - - // identify and collect calls in the graph - // outputFile << "In collect calls: "<< std::endl; - for (NodeDef& node : *graph->mutable_node()) { - const FunctionDef* func = ctx.FindInlinedFunction(node.op()); - // printf("Collecting Calls: %s %s \n", node.name().c_str(), node.op().c_str()); - if (func != nullptr) { - CallInfo call; - call.call_id = GetCallId(node); - call.node_name = node.name(); - call.function_name = node.op(); - call.node = &node; - call.device = node.device(); - call.attr = FunctionInstantiationAttributes(*func, node); - - int input_size = func->signature().input_arg_size(); - call.input_nodes.resize(input_size); - for (int i = 0; i < input_size; i++) { - call.input_nodes[i] = node.input(i); - } - calls.push_back(call); - } - } - return OkStatus(); -} - -Status CallRewriter::AddCallOp(const CallInfo& call_info, - const OpDef::ArgDef arg, +Status AddCallOp(const CallInfo& call_info, + const DataType& type, const string& input, + const string& prefix, int arg_id, NodeDef* call) { - string prefix = call_info.node_name; string call_name = strings::StrCat("Call", "_", arg_id); call->set_op(kCallOp); call->set_name(AddPrefixToNodeName(call_name, prefix)); //call->set_device(node.device()); call->add_input(input); - DataType type; - TF_RETURN_IF_ERROR(CopyArgType(arg, call_info.attr, &type)); - auto& attr = *call->mutable_attr(); - - //SetArgType(arg, call_info.attr, attr); - attr["T"].set_type(type); - attr["frame_name"].set_s(call_info.function_name); + attr["frame_name"].set_s(call_info.call_frame); attr["call_id"].set_i(call_info.call_id); attr["arg_id"].set_i(arg_id); attr["is_constant"].set_b(false); @@ -367,29 +324,26 @@ Status CallRewriter::AddCallOp(const CallInfo& call_info, return OkStatus(); } -Status CallRewriter::AddRetOp(const CallInfo& call_info, - const OpDef::ArgDef arg, +Status AddRetOp(const CallInfo& call_info, + const DataType& type, const string& input, + const string& prefix, int arg_id, NodeDef* ret) { - string prefix = call_info.node_name; string ret_name = 
strings::StrCat("Ret", "_", arg_id); ret->set_op(kRetOp); ret->set_name(AddPrefixToNodeName(ret_name, prefix)); ret->add_input(input); - DataType type; - TF_RETURN_IF_ERROR(CopyArgType(arg, call_info.attr, &type)); - auto& attr = *ret->mutable_attr(); attr["T"].set_type(type); - attr["frame_name"].set_s(call_info.function_name); + attr["frame_name"].set_s(call_info.call_frame); attr["call_id"].set_i(call_info.call_id); attr["arg_id"].set_i(arg_id); return OkStatus(); } -Status CallRewriter::ConnectInput(NodeDef* from, NodeDef* to) { +Status ConnectInput(NodeDef* from, NodeDef* to) { int to_input = to->input_size(); if (to_input == 1) { // it is Identity and we convert it to Merge. @@ -403,154 +357,69 @@ Status CallRewriter::ConnectInput(NodeDef* from, NodeDef* to) { return OkStatus(); } -Status CallRewriter::TransformCall(CallInfo& call_info) { - FuncInfo func_info; - - // inlines the body of a function and provides a struct with func_info - TF_RETURN_IF_ERROR(FindCompatibleOrInlineFunction( - call_info.function_name, call_info.attr, call_info.device, graph, func_info)); - - CHECK_EQ(call_info.input_nodes.size(), func_info.inputs.size()); - - std::vector call_nodes; - std::vector ret_nodes; - - call_nodes.resize(func_info.inputs.size()); - for (unsigned int arg_num = 0; arg_num < func_info.inputs.size(); arg_num++) { - call_nodes[arg_num] = graph->add_node(); - TF_CHECK_OK(AddCallOp(call_info, - func_info.input_def[arg_num], - call_info.input_nodes[arg_num], - arg_num, - call_nodes[arg_num])); - - call_nodes[arg_num]->set_device(call_info.device); - - // connect the input of the inlined function to feed from call. - TF_RETURN_IF_ERROR(ConnectInput(call_nodes[arg_num], func_info.inputs[arg_num])); - } - - ret_nodes.resize(func_info.outputs.size()); - for (unsigned int out_port = 0; out_port < func_info.outputs.size(); out_port++) { - ret_nodes[out_port] = graph->add_node(); - TF_CHECK_OK(AddRetOp(call_info, - func_info.output_def[out_port], - func_info.outputs[out_port], - out_port, - ret_nodes[out_port])); - ret_nodes[out_port]->set_device(call_info.device); - } - - // for each call create a control dependency to each return - // to facilitate dead propagation semantics - for (NodeDef* ret : ret_nodes) { - for (NodeDef* call : call_nodes) - *(ret->add_input()) = AsControlDependency(call->name()); - } - - if (ShouldPreserveOutputs(call_info.node_name)) { - // create an IdentityN with the same name of the initial function call - // so as to preserve the naming of the outputs. - // we re-use the initial node and we change (a) the op to IdentityN and - // (b) the inputs to point to the outputs of the ret_nodes - // The other information such as types, device placement etc remain the same. - // The IdentityN node will sync the outputs and therefore may result to performance degradation. 
- NodeDef* out = graph->add_node(); - out->set_op(kIdentityNOp); - out->set_name(call_info.node_name); - out->set_device(call_info.device); - AttrValue::ListValue* type_list = (*out->mutable_attr())["T"].mutable_list(); - for (const OpDef::ArgDef& arg : func_info.output_def) { - TF_RETURN_IF_ERROR(CopyArgType(arg, call_info.attr, type_list)); - } - for (unsigned int i = 0; i < func_info.outputs.size(); i++) { - *out->add_input() = ret_nodes[i]->name(); - } - } else { - for (unsigned int out_port = 0; out_port < func_info.outputs.size(); out_port++) { - ReplaceOutput(strings::StrCat(call_info.node_name, ":", out_port), ret_nodes[out_port]->name()); - } - if (func_info.outputs.size() == 1) { - ReplaceOutput(call_info.node_name, ret_nodes[0]->name()); - } - } - printf("Mark call %s (function %s) as transformed\n", call_info.node_name.c_str(), call_info.function_name.c_str()); - MarkCallTransformed(call_info); - - return OkStatus(); -} - Status InlineFunction(const FunctionDef& func_def, + const AttrSlice& func_instantiation_attr, const FunctionInliningContext& ctx, - const AttrSlice& func_attr, const string& device, - GraphDef* graph, FuncInfo& func_info) { - // std::unique_ptr item = GrapplerItemFromFunctionDef(func_def, func_attr, ctx.Library()); - GrapplerFunctionItem fitem; - GrapplerFunctionItem* item = &fitem; - TF_RETURN_IF_ERROR(MakeGrapplerFunctionItem( - func_def, func_attr, ctx.Library(), ctx.graph_version(), item)); + GraphDef* graph, FuncGradInfo& func_info) { + GrapplerFunctionItem item; + const int graph_version = graph->versions().producer(); + TF_RETURN_IF_ERROR(MakeGrapplerFunctionItem(func_def, func_instantiation_attr, ctx.FunctionLibrary(), graph_version, &item)); string prefix = func_def.signature().name(); - - if (!item) { - return errors::InvalidArgument( - "Failed to inline function ", func_def.signature().name()); - } int arg_size = func_def.signature().input_arg_size(); // create an inverse map of arg to provide name -> argument number std::unordered_map input_nodes; for (int i = 0; i < arg_size; ++i) { - const OpDef::ArgDef& arg = func_def.signature().input_arg(i); - input_nodes[arg.name()] = i; + const OpDef::ArgDef& input_arg = func_def.signature().input_arg(i); + input_nodes[input_arg.name()] = i; } - func_info.inputs.resize(arg_size); - func_info.input_def.resize(arg_size); + func_info.f.args.resize(arg_size); + func_info.f.arg_types.resize(arg_size); for (int i = 0; i < arg_size; ++i) { - const OpDef::ArgDef& arg = func_def.signature().input_arg(i); + const OpDef::ArgDef& input_arg = func_def.signature().input_arg(i); NodeDef* merge = graph->add_node(); merge->set_name(AddPrefixToNodeName(strings::StrCat("Input", "_", i), prefix)); merge->set_op(kIdentityOp); merge->set_device(device); - + DataType type; - TF_RETURN_IF_ERROR(CopyArgType(arg, func_attr, &type)); + TF_RETURN_IF_ERROR(CopyArgType(input_arg, func_instantiation_attr, &type)); auto& attr = *merge->mutable_attr(); attr["T"].set_type(type); - func_info.inputs[i] = merge; - func_info.input_def[i] = arg; + func_info.f.args[i] = merge; + func_info.f.arg_types[i] = type; } // prefix each node in function graph and place it to the global graph. // the inputs of each node need to be renamed as well to reflect the change. 
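// Example of the renaming (illustrative names): with prefix "EXPONENT", a body
// node "mul" becomes "EXPONENT/mul" and its inputs are rewritten the same way,
// while the i-th argument is read from the Identity/Merge node created above,
// e.g. "EXPONENT/Input_0".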
- for (NodeDef& func_body_node : *item->graph.mutable_node()) { + for (NodeDef& func_body_node : *item.mutable_function_body().mutable_node()) { const string& curr_name = func_body_node.name(); // If the func body node is func's input argument auto input_it = input_nodes.find(curr_name); if (input_it != input_nodes.end()) { CHECK_EQ(0, func_body_node.input_size()); + // If the func body node is func's input argument // Turn input placeholders into identity nodes - if (IsArg(func_body_node)) { - func_body_node.set_op(kIdentityOp); - } + func_body_node.set_op(kIdentityOp); + (*func_body_node.mutable_attr())["T"] = func_body_node.attr().at("dtype"); + func_body_node.mutable_attr()->erase("dtype"); + func_body_node.mutable_attr()->erase("shape"); // Connect merge with input arg - func_body_node.add_input(func_info.inputs[input_it->second]->name()); + int idx = input_nodes[curr_name]; + func_body_node.add_input(func_info.f.args[idx]->name()); } else { // Else if not an input_arg_node // Update the input names if any. for (string& input : *func_body_node.mutable_input()) { input = AddPrefixToNodeName(input, prefix); } - // If this is a return node, change the op to KIdentityOp - if(IsRetval(func_body_node)){ - func_body_node.set_op(kIdentityOp); - } // If the node has no input, make hook it up to the Merge nodes to ensure // it runs in the same frame as the other nodes of the function body. if (func_body_node.input_size() == 0) { - for (auto& func_input_node : func_info.inputs) { + for (auto& func_input_node : func_info.f.args) { *func_body_node.add_input() = AsControlDependency(func_input_node->name()); } } @@ -567,25 +436,324 @@ Status InlineFunction(const FunctionDef& func_def, graph->add_node()->Swap(&func_body_node); } - func_info.outputs.clear(); - func_info.outputs.resize(item->fetch.size()); - func_info.output_def.resize(item->fetch.size()); + func_info.f.rets.clear(); + func_info.f.rets.resize(item.fetch.size()); + func_info.f.ret_types.resize(item.fetch.size()); + + std::vector fetch = item.fetch; + for (unsigned int i = 0; i < fetch.size(); i++) { + const OutputArgInstantiation& output_arg = item.output(i); + func_info.f.rets[i] = AddPrefixToNodeName(output_arg.node_name, prefix); + func_info.f.ret_types[i] = output_arg.data_type; + } + + return OkStatus(); +} + +Status InlineFunctionAndGradient(const FunctionDef& fdef, + const AttrSlice& func_instantiation_attr, + const FunctionInliningContext& ctx, + const string& device, + GraphDef* graph, + FuncGradInfo& func_info) { + // Get func_def's gradient graph + const FunctionDef* fgdef = ctx.FindInlinedFunctionAndGradient(fdef.signature().name()); + if (fgdef == nullptr) { + return errors::InvalidArgument( + "Invalid argument, function ", fgdef->signature().name(), "can not be found", + "or not marked to be inlined"); + } + + GrapplerFunctionItem item; + const int graph_version = graph->versions().producer(); + TF_RETURN_IF_ERROR(MakeGrapplerFunctionItem(*fgdef, func_instantiation_attr, ctx.FunctionLibrary(), graph_version, &item)); + + string prefix = fdef.signature().name(); + size_t farg_size = fdef.signature().input_arg_size(); + size_t fret_size = fdef.signature().output_arg_size(); + size_t garg_size = fgdef->signature().input_arg_size() - farg_size; + size_t gret_size = fgdef->signature().output_arg_size() - fret_size; + + CHECK_EQ(farg_size, gret_size); + CHECK_EQ(garg_size, fret_size); + + func_info.f.arg_types.resize(farg_size); + func_info.g.arg_types.resize(farg_size + garg_size); + func_info.g.ret_types.resize(farg_size); + 
for (int i = 0; i < farg_size; i++) { + const OpDef::ArgDef& input_arg = fdef.signature().input_arg(i); + func_info.f.arg_types[i] = input_arg.type(); + func_info.g.arg_types[i] = input_arg.type(); + func_info.g.ret_types[i] = input_arg.type(); + } + + func_info.f.ret_types.resize(fret_size); + for (int i = 0; i < fret_size; i++) { + const OutputArgInstantiation& output_arg = item.output(i); + func_info.f.ret_types[i] = output_arg.data_type; + func_info.g.arg_types[farg_size + i] = output_arg.data_type; + } + + // create an inverse map of arg to provide name -> argument number + std::unordered_map input_map; + std::vector input_names; + input_names.resize(farg_size); + for (int i = 0; i < farg_size + garg_size; ++i) { + const OpDef::ArgDef& input_arg = fdef.signature().input_arg(i); + input_map[input_arg.name()] = i; + if (i < farg_size) { + input_names[i] = input_arg.name(); + } + } + func_info.f.args.resize(farg_size); + func_info.f.rets.resize(fret_size); + func_info.g.args.resize(farg_size + garg_size); + func_info.g.rets.resize(gret_size); + + // prefix each node in function graph and place it to the global graph. + // the inputs of each node need to be renamed as well to reflect the change. + for (NodeDef& n : *item.mutable_function_body().mutable_node()) { + // If the func body node is func's input argument + auto input_it = input_map.find(n.name()); + bool is_input = input_it != input_map.end(); + + if (is_input) { + CHECK_EQ(0, n.input_size()); + n.set_op(kIdentityOp); + (*n.mutable_attr())["T"] = n.attr().at("dtype"); + n.mutable_attr()->erase("dtype"); + n.mutable_attr()->erase("shape"); + } + + // Add the node name as a prefix to avoid collisions after inlining + n.set_name(AddPrefixToNodeName(n.name(), prefix)); + // Update the input names if any. + for (string& input : *n.mutable_input()) { + input = AddPrefixToNodeName(input, prefix); + } + + // Make sure the node is placed + if (n.device().empty()) + n.set_device(device); + + if (n.op() == kGradientOp) { + auto& attr = *n.mutable_attr(); + auto& n_ = attr["_n"].s(); + attr["_n"].set_s(AddPrefixToNodeName(n_, prefix)); + } + + // If the node has no input, make hook it up to the Merge nodes to ensure + // it runs in the same frame as the other nodes of the function body. + if (!is_input && n.input_size() == 0) { + // CHECK: constants from both in function and gradient are connected + // with the inputs of the function only. 
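// The control edges added below serve the same purpose as in the forward-only
// path: a node with no data inputs (typically a Const) would otherwise float
// outside the call frame, so it is tied to the forward arguments and therefore
// executes, or is marked dead, together with the rest of the inlined body.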
+ for (const string& arg : input_names) { + *n.add_input() = AsControlDependency(AddPrefixToNodeName(arg, prefix)); + } + } + + // Move the node to the main graph + NodeDef* nn = graph->add_node(); + nn->Swap(&n); + + if (is_input) { + int i = input_it->second; + if (i < farg_size) { + func_info.f.args[i] = nn; + func_info.g.args[i] = func_info.f.args[i]; + } else { + func_info.g.args[i] = nn; + } + } + } + + CHECK_EQ(fret_size + gret_size, item.fetch.size()); + + for (unsigned int i = 0; i < fret_size + gret_size; i++) { + const OutputArgInstantiation& output_arg = item.output(i); + string output_port = AddPrefixToNodeName(output_arg.node_name, prefix); + if (i < fret_size) { + func_info.f.rets[i] = output_port; + } else { + func_info.g.rets[i - fret_size] = output_port; + } + } + + return OkStatus(); +} + +Status CallRewriter::CollectCalls(std::vector& calls) { + + std::unordered_map call_map; + std::vector gradients; - for (unsigned int i = 0; i < item->fetch.size(); i++) { - func_info.outputs[i] = AddPrefixToNodeName(item->fetch[i], prefix); - func_info.output_def[i] = func_def.signature().output_arg(i); + // identify and collect calls in the graph + for (NodeDef& node : *graph->mutable_node()) { + if (node.op() == kGradientOp) { + gradients.push_back(&node); + } else { + const FunctionDef* func_def = ctx.FindInlinedFunction(node.op()); + if (func_def != nullptr) { + CallInfo& call = call_map[node.name()]; + call.call_id = GetCallId(node); + call.call_frame = node.op(); + call.fcall = &node; + } + } + } + for (NodeDef* gcall : gradients) { + if (gcall->attr().count("_n") > 0) { + const string& n = gcall->attr().at("_n").s(); + + auto fcall_it = call_map.find(n); + if (fcall_it == call_map.end()) { + return errors::InvalidArgument("Cannot find forward node for gradient ", + gcall->name()); + } + CallInfo& call = fcall_it->second; + call.gcall = gcall; + } } + for (const auto& it : call_map) { + calls.push_back(it.second); + } + return OkStatus(); +} + +Status CallRewriter::TransformNode(const CallInfo& info, + NodeDef* call, + const FuncInfo& f, + std::vector& call_nodes, + std::vector& ret_nodes) { + CHECK_EQ(call->input_size(), f.args.size()); + + call_nodes.resize(f.args.size()); + for (unsigned int i = 0; i < f.args.size(); i++) { + /* check if call node is already in place, if so, validate and skip */ + if (call_nodes[i] != nullptr) { + // TODO: validate call_id + // TODO: validate input + //CHECK_EQ(call_nodes[i]->input(0), call->input(i)); + } else { + call_nodes[i] = graph->add_node(); + TF_CHECK_OK(AddCallOp(info, + f.arg_types[i], + call->input(i), + call->name(), + i, + call_nodes[i])); + + call_nodes[i]->set_device(call->device()); + + // connect the input of the inlined function to feed from call. 
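// Note on ConnectInput (defined above): the argument node starts out as an
// Identity with a single producer; when a second call site of the same
// function (for example the recursive call inside the body) is wired to the
// same argument, ConnectInput switches the node to a Merge so that whichever
// incoming frame is alive supplies the value.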
+ TF_RETURN_IF_ERROR(ConnectInput(call_nodes[i], f.args[i])); + } + } + + // check for control edges in call + gtl::FlatSet control_inputs; + for (const string& input : call->input()) { + if (IsControlInput(input)) { + control_inputs.insert(NodeName(input)); + } + } + + for (NodeDef* call_node : call_nodes) { + for (const string& control_input : control_inputs) + *(call_node->add_input()) = AsControlDependency(control_input); + } + + ret_nodes.resize(f.rets.size()); + for (unsigned int i = 0; i < f.rets.size(); i++) { + if (ret_nodes[i] != nullptr) { + // TODO: validate call_id + // CHECK_EQ(ret_nodes[i]->input(0), f.rets[i]); + } else { + ret_nodes[i] = graph->add_node(); + TF_CHECK_OK(AddRetOp(info, + f.ret_types[i], + f.rets[i], + call->name(), + i, + ret_nodes[i])); + ret_nodes[i]->set_device(call->device()); + } + } + + if (ctx.IsFetchNode(call->name())) { + // create an IdentityN with the same name of the initial function call + // so as to preserve the naming of the outputs. + // we re-use the initial node and we change (a) the op to IdentityN and + // (b) the inputs to point to the outputs of the ret_nodes + // The other information such as types, device placement etc remain the same. + // The IdentityN node will sync the outputs and therefore may result to performance degradation. + NodeDef* out = graph->add_node(); + out->set_op(kIdentityNOp); + out->set_name(call->name()); + out->set_device(call->device()); + AttrValue::ListValue* type_list = (*out->mutable_attr())["T"].mutable_list(); + for (const DataType& type : f.ret_types) { + type_list->add_type(type); + } + for (unsigned int i = 0; i < f.rets.size(); i++) { + *out->add_input() = ret_nodes[i]->name(); + } + } else { + for (unsigned int i = 0; i < f.rets.size(); i++) { + ReplaceOutput(strings::StrCat(call->name(), ":", i), ret_nodes[i]->name()); + } +// if (f.rets.size() == 1) { + ReplaceOutput(call->name(), ret_nodes[0]->name()); +// } + } + + // for each call create a control dependency to each return + // to facilitate dead propagation semantics + for (NodeDef* ret : ret_nodes) { + for (NodeDef* call : call_nodes) + // TODO: Check if there is already a control dependency. 
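// One possible way to address the TODO above (a sketch; HasControlInput is a
// hypothetical helper, not part of this patch):
//
//   bool HasControlInput(const NodeDef& node, const string& name) {
//     const string dep = AsControlDependency(name);
//     for (const string& in : node.input()) {
//       if (in == dep) return true;
//     }
//     return false;
//   }
//
//   if (!HasControlInput(*ret, call->name()))
//     *(ret->add_input()) = AsControlDependency(call->name());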
+ *(ret->add_input()) = AsControlDependency(call->name()); + } + + return OkStatus(); +} + +Status CallRewriter::TransformCall(const CallInfo& call_info) { + FuncGradInfo func_info; + TransformationResult result; + + // inlines the body of a function and provides a struct with func_info + TF_RETURN_IF_ERROR(FindCompatibleOrInlineFunction(call_info, graph, func_info)); + + result.call_id = call_info.call_id; + result.call_frame = call_info.call_frame; + result.transformed_node = call_info.fcall; + + TF_RETURN_IF_ERROR(TransformNode(call_info, call_info.fcall, func_info.f, result.call_nodes, result.ret_nodes)); + MarkTransformed(result); + + if (call_info.hasGradient()) { + TransformationResult grad_result; + grad_result.call_id = call_info.call_id; + grad_result.call_frame = call_info.call_frame; + grad_result.transformed_node = call_info.gcall; + grad_result.call_nodes = result.call_nodes; + // keep all the inputs of the function + TF_RETURN_IF_ERROR(TransformNode(call_info, call_info.gcall, func_info.g, grad_result.call_nodes, grad_result.ret_nodes)); + MarkTransformed(grad_result); + } + MarkCallTransformed(call_info); return OkStatus(); } -// new Status CallRewriter::FindCompatibleOrInlineFunction( - const string& func_name, - const AttrSlice& func_attr, - const string& device, + const CallInfo& call, GraphDef* graph, - FuncInfo& func_info) { + FuncGradInfo& func_info) { + CHECK_NOTNULL(call.fcall); + const string& func_name = call.fcall->op(); + string device = call.fcall->device(); const auto& it = transformed_functions_.find(func_name); // maybe it is not wise to discard call attributes // possible type specialization? @@ -599,25 +767,80 @@ Status CallRewriter::FindCompatibleOrInlineFunction( "Invalid argument, function ", func_name, "can not be found", "or not marked to be inlined"); } - TF_RETURN_IF_ERROR( - InlineFunction(*func_def, ctx, func_attr, device, graph, func_info)); + + const AttrSlice func_instantiation_attr = + FunctionInstantiationAttributes(*func_def, *call.fcall); + + if (call.hasGradient()) { + TF_RETURN_IF_ERROR( + InlineFunctionAndGradient(*func_def, func_instantiation_attr, ctx, device, graph, func_info)); + } else { + TF_RETURN_IF_ERROR( + InlineFunction(*func_def, func_instantiation_attr, ctx, device, graph, func_info)); + } transformed_functions_[func_name] = func_info; printf("Store inlined function %s\n", func_name.c_str()); return OkStatus(); } +void CallRewriter::Flush() { + + if (!transformed_calls_.empty()) { + // garbage collect the transformed call nodes + int last = graph->node_size() - 1; + for (int i = graph->node_size() - 1; i >= 0; --i) { + const NodeDef& node = graph->node(i); + if (nodes_to_delete.find(node.name()) != nodes_to_delete.end()) { + graph->mutable_node()->SwapElements(i,last); + last--; + } + } + graph->mutable_node()->DeleteSubrange(last + 1, + graph->node_size() - last - 1); + } + if (!output_map_.empty()) { + for (NodeDef& node : *graph->mutable_node()) { + std::vector control_nodes; + int last = node.input_size() - 1; + + for (int i = node.input_size() - 1; i >= 0; --i) { + string& in = *node.mutable_input(i); + auto it = output_map_.find(in); + if (it != output_map_.end()) { + in = it->second; + } + if (IsControlInput(in)) { + auto it = transformed_calls_.find(NodeName(in)); + if (it != transformed_calls_.end()) { + node.mutable_input()->SwapElements(i, last); + control_nodes.push_back(it->second); + last--; + } + } + node.mutable_input()->DeleteSubrange(last + 1, + node.input_size() - last - 1); + for (TransformationResult& 
result : control_nodes) { + for (NodeDef* ret_node : result.ret_nodes) { + *node.add_input() = AsControlDependency(ret_node->name()); + } + } + } + } + } + transformed_calls_.clear(); + nodes_to_delete.clear(); + output_map_.clear(); +} + } // namespace Status FunctionTransformation::Optimize(Cluster* cluster, const GrapplerItem& item, GraphDef* output) { - - FunctionInliningContext ctx(item); CallRewriter call_rewriter(item, output, ctx); *output = item.graph; if (!ctx.HasInlinedFunctions()) { - // outputFile << "No inlining functions!"<name().c_str(), SummarizeGraphDef(*output).c_str()); } calls.clear(); call_rewriter.Flush(); } call_rewriter.Flush(); - // outputFile.close(); printf("After finalizing:\n %s\n", SummarizeGraphDef(*output).c_str()); *output->mutable_versions() = item.graph.versions(); @@ -678,4 +900,4 @@ Status FunctionTransformation::Optimize(Cluster* cluster, const GrapplerItem& it } } // end namespace grappler -} // end namespace tensorflow +} // end namespace tensorflow \ No newline at end of file From 96e919c2e787d3462f72245eb0c262ec56e448a8 Mon Sep 17 00:00:00 2001 From: George Vasilakopoulos Date: Thu, 6 Jun 2024 15:48:54 +0000 Subject: [PATCH 42/53] Debugging --- .../grappler/optimizers/function_transformation.cc | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/tensorflow/core/grappler/optimizers/function_transformation.cc b/tensorflow/core/grappler/optimizers/function_transformation.cc index faad2f658775f1..5950559287a124 100644 --- a/tensorflow/core/grappler/optimizers/function_transformation.cc +++ b/tensorflow/core/grappler/optimizers/function_transformation.cc @@ -404,9 +404,6 @@ Status InlineFunction(const FunctionDef& func_def, // If the func body node is func's input argument // Turn input placeholders into identity nodes func_body_node.set_op(kIdentityOp); - (*func_body_node.mutable_attr())["T"] = func_body_node.attr().at("dtype"); - func_body_node.mutable_attr()->erase("dtype"); - func_body_node.mutable_attr()->erase("shape"); // Connect merge with input arg int idx = input_nodes[curr_name]; func_body_node.add_input(func_info.f.args[idx]->name()); @@ -416,6 +413,11 @@ Status InlineFunction(const FunctionDef& func_def, for (string& input : *func_body_node.mutable_input()) { input = AddPrefixToNodeName(input, prefix); } + // If this is a return node, change the op to KIdentityOp + if(IsRetval(func_body_node)){ + func_body_node.set_op(kIdentityOp); + } + // If the node has no input, make hook it up to the Merge nodes to ensure // it runs in the same frame as the other nodes of the function body. if (func_body_node.input_size() == 0) { @@ -520,9 +522,6 @@ Status InlineFunctionAndGradient(const FunctionDef& fdef, if (is_input) { CHECK_EQ(0, n.input_size()); n.set_op(kIdentityOp); - (*n.mutable_attr())["T"] = n.attr().at("dtype"); - n.mutable_attr()->erase("dtype"); - n.mutable_attr()->erase("shape"); } // Add the node name as a prefix to avoid collisions after inlining @@ -541,6 +540,9 @@ Status InlineFunctionAndGradient(const FunctionDef& fdef, auto& n_ = attr["_n"].s(); attr["_n"].set_s(AddPrefixToNodeName(n_, prefix)); } + if(IsRetval(n)){ + n.set_op(kIdentityOp); + } // If the node has no input, make hook it up to the Merge nodes to ensure // it runs in the same frame as the other nodes of the function body. 
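A rough sketch of the graph produced for one call site by the rewrite above
(illustrative names: "EXPONENT" is the function from recursion-tests/exponents.py,
"y" the original call node, K the call_id assigned by GetCallId):

    x -> y/Call_0   (op: Call,   T: dtype of x, frame_name: "EXPONENT", call_id: K, arg_id: 0)
                      -> EXPONENT/Input_0 (Identity, turned into Merge by ConnectInput)
    n -> y/Call_1   (op: Call,   T: dtype of n, frame_name: "EXPONENT", call_id: K, arg_id: 1)
                      -> EXPONENT/Input_1 (Identity, turned into Merge by ConnectInput)
                 ... function body nodes, all prefixed "EXPONENT/..." ...
    EXPONENT/<output> -> y/Ret_0 (op: Return, T: dtype of the result, frame_name: "EXPONENT",
                                  call_id: K, arg_id: 0) -> former consumers of y:0

Every Return node additionally receives a control input from each Call node of
the same call so that deadness propagates, and if "y" is a fetch node an
IdentityN named "y" is kept in front of the Return outputs to preserve the
original output names.
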
From b8d49c016c6339e58315bb929c9007abe2acde98 Mon Sep 17 00:00:00 2001 From: George Vasilakopoulos Date: Thu, 6 Jun 2024 17:06:47 +0000 Subject: [PATCH 43/53] Prevent SegFault on error; Migration changes --- .../core/grappler/optimizers/function_transformation.cc | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tensorflow/core/grappler/optimizers/function_transformation.cc b/tensorflow/core/grappler/optimizers/function_transformation.cc index 5950559287a124..7c4a9c9693b571 100644 --- a/tensorflow/core/grappler/optimizers/function_transformation.cc +++ b/tensorflow/core/grappler/optimizers/function_transformation.cc @@ -462,7 +462,7 @@ Status InlineFunctionAndGradient(const FunctionDef& fdef, const FunctionDef* fgdef = ctx.FindInlinedFunctionAndGradient(fdef.signature().name()); if (fgdef == nullptr) { return errors::InvalidArgument( - "Invalid argument, function ", fgdef->signature().name(), "can not be found", + "Invalid argument, gradient of function ", fdef.signature().name(), "can not be found", "or not marked to be inlined"); } @@ -604,8 +604,9 @@ Status CallRewriter::CollectCalls(std::vector& calls) { } } for (NodeDef* gcall : gradients) { - if (gcall->attr().count("_n") > 0) { - const string& n = gcall->attr().at("_n").s(); + if (gcall->attr().count("f") > 0) { + printf("Debug string: %s \n\n", gcall->attr().at("f").DebugString().c_str()); + const string& n = gcall->attr().at("f").func().name(); auto fcall_it = call_map.find(n); if (fcall_it == call_map.end()) { From 6868634ae0099b41b0bc2d1e6d76544759efde65 Mon Sep 17 00:00:00 2001 From: George Vasilakopoulos Date: Tue, 11 Jun 2024 16:35:05 +0000 Subject: [PATCH 44/53] Debugging --- tensorflow/core/common_runtime/function.cc | 18 +++- .../optimizers/function_transformation.cc | 101 ++++++++++++------ 2 files changed, 81 insertions(+), 38 deletions(-) diff --git a/tensorflow/core/common_runtime/function.cc b/tensorflow/core/common_runtime/function.cc index 3789b94a20757e..8cab5632b8f29d 100644 --- a/tensorflow/core/common_runtime/function.cc +++ b/tensorflow/core/common_runtime/function.cc @@ -45,6 +45,7 @@ limitations under the License. #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/versions.pb.h" +#include "tensorflow/core/framework/graph_def_util.h" #include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/graph/control_flow.h" #include "tensorflow/core/graph/node_builder.h" @@ -1541,17 +1542,24 @@ std::unique_ptr SymbolicGradientHelper::Compute() { g)); // Remove the old return nodes from the function body. - for (Node* n : gbody->ret_nodes) { - g->RemoveNode(n); - } - gbody->ret_types = fbody_->arg_types; + // for (Node* n : gbody->ret_nodes) { + // g->RemoveNode(n); + // } + // gbody->ret_types = fbody_->arg_types; + + // Concatenate vectors + gbody->ret_types.insert(gbody->ret_types.end(), fbody_->arg_types.begin(), fbody_->arg_types.end()); + + printf("After adding gradients:\n", SummarizeGraphDef(g->ToGraphDefDebug()).c_str()); + + // TODO(apassos): use the right dtype for gradients of resource variables for (int i = 0; i < gbody->ret_types.size(); ++i) { if (gbody->ret_types[i] == DT_RESOURCE) { gbody->ret_types[i] = DT_FLOAT; } } - gbody->ret_nodes.clear(); + // gbody->ret_nodes.clear(); // Add new return nodes to the function gradient body for each node // in 'x_grad_nodes'. 
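// With the forward return nodes kept (see above), the body produced by
// SymbolicGradientHelper now yields the forward results followed by the
// gradients of the inputs. For a forward function such as
// EXPONENT(x: float32, n: int32) -> y: float32 this gives, roughly,
// inputs (x, n, dy) and outputs (y, dx, dn), which is the layout the
// function_transformation pass checks below (gret_size == fret_size + farg_size).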
const int arg_types_size = static_cast(fbody_->arg_types.size()); diff --git a/tensorflow/core/grappler/optimizers/function_transformation.cc b/tensorflow/core/grappler/optimizers/function_transformation.cc index 7c4a9c9693b571..2e08acbedc4eed 100644 --- a/tensorflow/core/grappler/optimizers/function_transformation.cc +++ b/tensorflow/core/grappler/optimizers/function_transformation.cc @@ -26,6 +26,8 @@ limitations under the License. #include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/op_def.pb.h" #include "tensorflow/core/framework/versions.pb.h" +#include "tensorflow/core/common_runtime/function_def_utils.h" +#include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/grappler/grappler_item.h" #include "tensorflow/core/grappler/op_types.h" #include "tensorflow/core/grappler/utils.h" @@ -155,6 +157,13 @@ class FunctionInliningContext { return function_library_; } + Status AddFunctionDef(const FunctionDef& fdef) { + TF_RETURN_IF_ERROR(function_library_.AddFunctionDef(fdef)); + inlined_functions_[fdef.signature().name()] = function_library_.Find(fdef.signature().name()); + return OkStatus(); + } + + bool HasInlinedFunctions() const { return !inlined_functions_.empty(); } bool IsInlinedFunction(const string& name) const { @@ -234,7 +243,7 @@ struct TransformationResult { class CallRewriter { public: - explicit CallRewriter(const GrapplerItem& item_, GraphDef* graph_, const FunctionInliningContext& ctx_) + explicit CallRewriter(const GrapplerItem& item_, GraphDef* graph_, FunctionInliningContext& ctx_) : graph(graph_), ctx(ctx_), item(item_) { } ~CallRewriter() { @@ -292,7 +301,7 @@ class CallRewriter { } GraphDef* graph; - const FunctionInliningContext& ctx; + FunctionInliningContext& ctx; const GrapplerItem& item; std::unordered_map transformed_functions_; std::unordered_map output_map_; @@ -454,33 +463,59 @@ Status InlineFunction(const FunctionDef& func_def, Status InlineFunctionAndGradient(const FunctionDef& fdef, const AttrSlice& func_instantiation_attr, - const FunctionInliningContext& ctx, + FunctionInliningContext& ctx, const string& device, GraphDef* graph, FuncGradInfo& func_info) { // Get func_def's gradient graph - const FunctionDef* fgdef = ctx.FindInlinedFunctionAndGradient(fdef.signature().name()); - if (fgdef == nullptr) { - return errors::InvalidArgument( - "Invalid argument, gradient of function ", fdef.signature().name(), "can not be found", - "or not marked to be inlined"); - } + + // const FunctionDef* fgdef = ctx.FindInlinedFunctionAndGradient(fdef.signature().name()); + // if(fgdef == nullptr){ + // std::unique_ptr fbody; + // TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(fdef,AttrSlice(&fdef.attr()),&ctx.FunctionLibrary(), &fbody)); - GrapplerFunctionItem item; - const int graph_version = graph->versions().producer(); - TF_RETURN_IF_ERROR(MakeGrapplerFunctionItem(*fgdef, func_instantiation_attr, ctx.FunctionLibrary(), graph_version, &item)); + // printf("Original graph %s\n\n",SummarizeGraphDef((fbody)->graph->ToGraphDefDebug()).c_str()); + + // fbody = SymbolicGradient(*fbody); + + // printf("Gradient graph %s\n\n",SummarizeGraphDef((fbody)->graph->ToGraphDefDebug()).c_str()); + + // FunctionDef graddef = fbody->record->fdef(); + // printf("Debug string: %s\n",graddef.DebugString().c_str()); + // std::string grad_name = strings::StrCat(fdef.signature().name(), "Grad"); + // graddef.mutable_signature()->set_name(grad_name); + // TF_RETURN_IF_ERROR(ctx.AddFunctionDef(graddef)); + // fgdef = 
ctx.FindInlinedFunctionAndGradient(fdef.signature().name()); + // } + GraphDef grad_graph; + std::unique_ptr fbody; + TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(fdef,AttrSlice(&fdef.attr()),&ctx.FunctionLibrary(), &fbody)); + + printf("Original graph %s\n\n",SummarizeGraphDef((fbody)->graph->ToGraphDefDebug()).c_str()); + + fbody = SymbolicGradient(*fbody); + + printf("Gradient graph %s\n\n",SummarizeGraphDef((fbody)->graph->ToGraphDefDebug()).c_str()); + + fbody->graph->ToGraphDef(&grad_graph); + + + + // GrapplerFunctionItem item; + const int graph_version = fbody->graph->versions().producer(); + // TF_RETURN_IF_ERROR(MakeGrapplerFunctionItem(*fgdef, func_instantiation_attr, ctx.FunctionLibrary(), graph_version, &item)); string prefix = fdef.signature().name(); size_t farg_size = fdef.signature().input_arg_size(); size_t fret_size = fdef.signature().output_arg_size(); - size_t garg_size = fgdef->signature().input_arg_size() - farg_size; - size_t gret_size = fgdef->signature().output_arg_size() - fret_size; + size_t garg_size = fbody->arg_nodes.size();// - farg_size; + size_t gret_size = fbody->ret_nodes.size() - fret_size; CHECK_EQ(farg_size, gret_size); - CHECK_EQ(garg_size, fret_size); + CHECK_EQ(garg_size, fret_size + farg_size); func_info.f.arg_types.resize(farg_size); - func_info.g.arg_types.resize(farg_size + garg_size); + func_info.g.arg_types.resize(garg_size); func_info.g.ret_types.resize(farg_size); for (int i = 0; i < farg_size; i++) { const OpDef::ArgDef& input_arg = fdef.signature().input_arg(i); @@ -491,30 +526,29 @@ Status InlineFunctionAndGradient(const FunctionDef& fdef, func_info.f.ret_types.resize(fret_size); for (int i = 0; i < fret_size; i++) { - const OutputArgInstantiation& output_arg = item.output(i); - func_info.f.ret_types[i] = output_arg.data_type; - func_info.g.arg_types[farg_size + i] = output_arg.data_type; + // const OutputArgInstantiation& output_arg = item.output(i); + func_info.f.ret_types[i] = fbody->ret_types[i]; + func_info.g.arg_types[farg_size + i] = fbody->ret_types[i]; } // create an inverse map of arg to provide name -> argument number std::unordered_map input_map; std::vector input_names; input_names.resize(farg_size); - for (int i = 0; i < farg_size + garg_size; ++i) { - const OpDef::ArgDef& input_arg = fdef.signature().input_arg(i); - input_map[input_arg.name()] = i; + for (int i = 0; i < garg_size; ++i) { + input_map[fbody->arg_nodes[i]->name()] = i; if (i < farg_size) { - input_names[i] = input_arg.name(); + input_names[i] = fbody->arg_nodes[i]->name(); } } func_info.f.args.resize(farg_size); func_info.f.rets.resize(fret_size); - func_info.g.args.resize(farg_size + garg_size); + func_info.g.args.resize(garg_size); func_info.g.rets.resize(gret_size); // prefix each node in function graph and place it to the global graph. // the inputs of each node need to be renamed as well to reflect the change. 
- for (NodeDef& n : *item.mutable_function_body().mutable_node()) { + for (NodeDef& n : *grad_graph.mutable_node()) { // If the func body node is func's input argument auto input_it = input_map.find(n.name()); bool is_input = input_it != input_map.end(); @@ -535,11 +569,11 @@ Status InlineFunctionAndGradient(const FunctionDef& fdef, if (n.device().empty()) n.set_device(device); - if (n.op() == kGradientOp) { - auto& attr = *n.mutable_attr(); - auto& n_ = attr["_n"].s(); - attr["_n"].set_s(AddPrefixToNodeName(n_, prefix)); - } + // if (n.op() == kGradientOp) { + // auto& attr = *n.mutable_attr(); + // std::string& name = *attr.at("f").mutable_func()->mutable_name(); + // name = AddPrefixToNodeName(name, prefix); + // } if(IsRetval(n)){ n.set_op(kIdentityOp); } @@ -569,11 +603,10 @@ Status InlineFunctionAndGradient(const FunctionDef& fdef, } } - CHECK_EQ(fret_size + gret_size, item.fetch.size()); + CHECK_EQ(fret_size + gret_size, fbody->arg_nodes.size()); for (unsigned int i = 0; i < fret_size + gret_size; i++) { - const OutputArgInstantiation& output_arg = item.output(i); - string output_port = AddPrefixToNodeName(output_arg.node_name, prefix); + string output_port = AddPrefixToNodeName(fbody->ret_nodes[i]->name(), prefix); if (i < fret_size) { func_info.f.rets[i] = output_port; } else { @@ -843,6 +876,8 @@ Status FunctionTransformation::Optimize(Cluster* cluster, const GrapplerItem& it CallRewriter call_rewriter(item, output, ctx); *output = item.graph; + + printf("Before optimizer: %s\n\n",SummarizeGraphDef(*output).c_str()); if (!ctx.HasInlinedFunctions()) { return OkStatus(); } From 04af8f0c38a88b99d590b55e9d303d4fa8583efc Mon Sep 17 00:00:00 2001 From: George Vasilakopoulos Date: Thu, 13 Jun 2024 17:12:49 +0000 Subject: [PATCH 45/53] Updated Test file --- recursion-tests/exponents.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/recursion-tests/exponents.py b/recursion-tests/exponents.py index 8ef9f320b3c8ea..f107979ec154a3 100644 --- a/recursion-tests/exponents.py +++ b/recursion-tests/exponents.py @@ -17,8 +17,8 @@ @function.Defun(tf.float32, tf.int32, func_name="EXPONENT", out_names=["ret"]) def ExpImpl(x, n): return tf.cond(tf.equal(n,0), - lambda: tf.cast(tf.constant(1),tf.float32), - lambda: x*x) + lambda: tf.constant(1.0), + lambda: x*exp(x,n-1)) # @function.Defun(tf.int32, func_name="Fac", out_names=["ret"]) @@ -34,7 +34,7 @@ def ExpImpl(x, n): x = tf.compat.v1.get_variable('n_var', [], initializer=tf.constant_initializer(4.0)) y = ExpImpl(x,2) -train_op = tf.compat.v1.train.GradientDescentOptimizer(0.01).minimize(y) +train_op = tf.compat.v1.train.GradientDescentOptimizer(0.1).minimize(y) print(tf.compat.v1.get_default_graph().as_graph_def()) From b681bd25def0fd20a905c712afcb67cb3d17e9d1 Mon Sep 17 00:00:00 2001 From: George Vasilakopoulos Date: Thu, 13 Jun 2024 17:13:20 +0000 Subject: [PATCH 46/53] Debugging --- .../optimizers/function_transformation.cc | 31 ++++++++++--------- 1 file changed, 17 insertions(+), 14 deletions(-) diff --git a/tensorflow/core/grappler/optimizers/function_transformation.cc b/tensorflow/core/grappler/optimizers/function_transformation.cc index 2e08acbedc4eed..543678d8a43747 100644 --- a/tensorflow/core/grappler/optimizers/function_transformation.cc +++ b/tensorflow/core/grappler/optimizers/function_transformation.cc @@ -509,26 +509,28 @@ Status InlineFunctionAndGradient(const FunctionDef& fdef, size_t farg_size = fdef.signature().input_arg_size(); size_t fret_size = fdef.signature().output_arg_size(); size_t garg_size = 
fbody->arg_nodes.size();// - farg_size; - size_t gret_size = fbody->ret_nodes.size() - fret_size; + size_t gret_size = fbody->ret_nodes.size();// - fret_size; - CHECK_EQ(farg_size, gret_size); + CHECK_EQ(farg_size, gret_size - fret_size); CHECK_EQ(garg_size, fret_size + farg_size); func_info.f.arg_types.resize(farg_size); func_info.g.arg_types.resize(garg_size); - func_info.g.ret_types.resize(farg_size); + func_info.g.ret_types.resize(gret_size); for (int i = 0; i < farg_size; i++) { const OpDef::ArgDef& input_arg = fdef.signature().input_arg(i); func_info.f.arg_types[i] = input_arg.type(); func_info.g.arg_types[i] = input_arg.type(); - func_info.g.ret_types[i] = input_arg.type(); } func_info.f.ret_types.resize(fret_size); - for (int i = 0; i < fret_size; i++) { + for (int i = 0; i < gret_size; i++) { // const OutputArgInstantiation& output_arg = item.output(i); - func_info.f.ret_types[i] = fbody->ret_types[i]; - func_info.g.arg_types[farg_size + i] = fbody->ret_types[i]; + if(i < fret_size){ + func_info.f.ret_types[i] = fbody->ret_types[i]; + func_info.g.arg_types[farg_size + i] = fbody->ret_types[i]; + } + func_info.g.ret_types[i] = fbody->ret_types[i]; } // create an inverse map of arg to provide name -> argument number @@ -603,15 +605,14 @@ Status InlineFunctionAndGradient(const FunctionDef& fdef, } } - CHECK_EQ(fret_size + gret_size, fbody->arg_nodes.size()); + CHECK_EQ(gret_size, fbody->arg_nodes.size()); - for (unsigned int i = 0; i < fret_size + gret_size; i++) { + for (unsigned int i = 0; i < gret_size; i++) { string output_port = AddPrefixToNodeName(fbody->ret_nodes[i]->name(), prefix); if (i < fret_size) { func_info.f.rets[i] = output_port; - } else { - func_info.g.rets[i - fret_size] = output_port; } + func_info.g.rets[i] = output_port; } return OkStatus(); @@ -629,7 +630,7 @@ Status CallRewriter::CollectCalls(std::vector& calls) { } else { const FunctionDef* func_def = ctx.FindInlinedFunction(node.op()); if (func_def != nullptr) { - CallInfo& call = call_map[node.name()]; + CallInfo& call = call_map[node.op()]; call.call_id = GetCallId(node); call.call_frame = node.op(); call.fcall = &node; @@ -643,8 +644,9 @@ Status CallRewriter::CollectCalls(std::vector& calls) { auto fcall_it = call_map.find(n); if (fcall_it == call_map.end()) { - return errors::InvalidArgument("Cannot find forward node for gradient ", - gcall->name()); + // return errors::InvalidArgument("Cannot find forward node for gradient ", + // gcall->name()); + continue; } CallInfo& call = fcall_it->second; call.gcall = gcall; @@ -775,6 +777,7 @@ Status CallRewriter::TransformCall(const CallInfo& call_info) { grad_result.call_frame = call_info.call_frame; grad_result.transformed_node = call_info.gcall; grad_result.call_nodes = result.call_nodes; + grad_result.ret_nodes = result.ret_nodes; // keep all the inputs of the function TF_RETURN_IF_ERROR(TransformNode(call_info, call_info.gcall, func_info.g, grad_result.call_nodes, grad_result.ret_nodes)); MarkTransformed(grad_result); From 746700f261c7c26c614d3e2238ae1a0d4ba29aca Mon Sep 17 00:00:00 2001 From: George Vasilakopoulos Date: Thu, 13 Jun 2024 17:13:48 +0000 Subject: [PATCH 47/53] Enable function_optimizer --- tensorflow/core/grappler/optimizers/meta_optimizer.cc | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer.cc b/tensorflow/core/grappler/optimizers/meta_optimizer.cc index 21f666d509539d..f2a1e80e4f09e5 100644 --- a/tensorflow/core/grappler/optimizers/meta_optimizer.cc +++ 
b/tensorflow/core/grappler/optimizers/meta_optimizer.cc @@ -350,9 +350,9 @@ Status MetaOptimizer::InitializeOptimizers( USER_IS_EXPERIMENTAL_BOTH(function_optimization)) { VLOG(2) << "function_optimization is not implemented in TFG yet"; } else { - // optimizers->push_back(std::make_unique( - // cfg_.function_optimization(), - // /*lower_control_flow=*/LowerControlFlow())); + optimizers->push_back(std::make_unique( + cfg_.function_optimization(), + /*lower_control_flow=*/LowerControlFlow())); } } if (BOTH_NOT_OFF(common_subgraph_elimination) && From d0e98a62a9e56319071ac28b0b15b9df27a39384 Mon Sep 17 00:00:00 2001 From: George Vasilakopoulos Date: Tue, 18 Jun 2024 14:41:49 +0000 Subject: [PATCH 48/53] Remove pending check --- tensorflow/core/common_runtime/graph_constructor.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/core/common_runtime/graph_constructor.cc b/tensorflow/core/common_runtime/graph_constructor.cc index 0bba0bf4881dce..1d2a7707c1cee3 100644 --- a/tensorflow/core/common_runtime/graph_constructor.cc +++ b/tensorflow/core/common_runtime/graph_constructor.cc @@ -587,7 +587,7 @@ void GraphConstructor::UpdatePendingCountAndReady(int processed, int* current_pending_count = &pending_count_[output]; if (*current_pending_count == 0 && is_function_call) continue; if (*current_pending_count == 0 && merge_node_indices_.count(output) == 1) continue; - CHECK_GT(*current_pending_count, 0); + // CHECK_GT(*current_pending_count, 0); (*current_pending_count)--; if (*current_pending_count == 0) { ready_.insert(output); From 9286920557880c2ad7c7410cac17f114f5846617 Mon Sep 17 00:00:00 2001 From: George Vasilakopoulos Date: Tue, 18 Jun 2024 14:43:42 +0000 Subject: [PATCH 49/53] Distinguish between regular and gradient call/return nodes --- .../optimizers/function_transformation.cc | 30 ++++++++++++------- 1 file changed, 19 insertions(+), 11 deletions(-) diff --git a/tensorflow/core/grappler/optimizers/function_transformation.cc b/tensorflow/core/grappler/optimizers/function_transformation.cc index 543678d8a43747..f03afdf1aae692 100644 --- a/tensorflow/core/grappler/optimizers/function_transformation.cc +++ b/tensorflow/core/grappler/optimizers/function_transformation.cc @@ -267,7 +267,8 @@ class CallRewriter { Status TransformNode(const CallInfo& info, NodeDef* call, const FuncInfo& f, std::vector& call_nodes, - std::vector& ret_nodes); + std::vector& ret_nodes, + bool is_gradient_node); void ReplaceOutput(const string& old_output, const string& new_output) { // maybe some more checks @@ -316,7 +317,7 @@ Status AddCallOp(const CallInfo& call_info, const DataType& type, const string& input, const string& prefix, - int arg_id, NodeDef* call) { + int arg_id, NodeDef* call, bool is_gradient_call = false) { string call_name = strings::StrCat("Call", "_", arg_id); call->set_op(kCallOp); call->set_name(AddPrefixToNodeName(call_name, prefix)); @@ -329,6 +330,7 @@ Status AddCallOp(const CallInfo& call_info, attr["call_id"].set_i(call_info.call_id); attr["arg_id"].set_i(arg_id); attr["is_constant"].set_b(false); + attr["is_gradient"].set_b(is_gradient_call); return OkStatus(); } @@ -337,7 +339,7 @@ Status AddRetOp(const CallInfo& call_info, const DataType& type, const string& input, const string& prefix, - int arg_id, NodeDef* ret) { + int arg_id, NodeDef* ret, bool is_gradient_return = false) { string ret_name = strings::StrCat("Ret", "_", arg_id); ret->set_op(kRetOp); ret->set_name(AddPrefixToNodeName(ret_name, prefix)); @@ -348,6 +350,7 @@ Status AddRetOp(const 
CallInfo& call_info, attr["frame_name"].set_s(call_info.call_frame); attr["call_id"].set_i(call_info.call_id); attr["arg_id"].set_i(arg_id); + attr["is_gradient"].set_b(is_gradient_return); return OkStatus(); } @@ -663,7 +666,7 @@ Status CallRewriter::TransformNode(const CallInfo& info, NodeDef* call, const FuncInfo& f, std::vector& call_nodes, - std::vector& ret_nodes) { + std::vector& ret_nodes, bool is_gradient_node = false) { CHECK_EQ(call->input_size(), f.args.size()); call_nodes.resize(f.args.size()); @@ -680,7 +683,8 @@ Status CallRewriter::TransformNode(const CallInfo& info, call->input(i), call->name(), i, - call_nodes[i])); + call_nodes[i], + is_gradient_node)); call_nodes[i]->set_device(call->device()); @@ -714,7 +718,8 @@ Status CallRewriter::TransformNode(const CallInfo& info, f.rets[i], call->name(), i, - ret_nodes[i])); + ret_nodes[i], + is_gradient_node)); ret_nodes[i]->set_device(call->device()); } } @@ -741,17 +746,20 @@ Status CallRewriter::TransformNode(const CallInfo& info, for (unsigned int i = 0; i < f.rets.size(); i++) { ReplaceOutput(strings::StrCat(call->name(), ":", i), ret_nodes[i]->name()); } -// if (f.rets.size() == 1) { + if (f.rets.size() == 1) { ReplaceOutput(call->name(), ret_nodes[0]->name()); -// } + } } // for each call create a control dependency to each return // to facilitate dead propagation semantics for (NodeDef* ret : ret_nodes) { - for (NodeDef* call : call_nodes) + for (NodeDef* call : call_nodes){ + if(ret->attr().at("is_gradient").b() != call->attr().at("is_gradient").b()) continue; + printf("Adding control edge from %s to %s\n",call->name().c_str(),ret->name().c_str()); // TODO: Check if there is already a control dependency. *(ret->add_input()) = AsControlDependency(call->name()); + } } return OkStatus(); @@ -768,7 +776,7 @@ Status CallRewriter::TransformCall(const CallInfo& call_info) { result.call_frame = call_info.call_frame; result.transformed_node = call_info.fcall; - TF_RETURN_IF_ERROR(TransformNode(call_info, call_info.fcall, func_info.f, result.call_nodes, result.ret_nodes)); + TF_RETURN_IF_ERROR(TransformNode(call_info, call_info.fcall, func_info.f, result.call_nodes, result.ret_nodes,false)); MarkTransformed(result); if (call_info.hasGradient()) { @@ -779,7 +787,7 @@ Status CallRewriter::TransformCall(const CallInfo& call_info) { grad_result.call_nodes = result.call_nodes; grad_result.ret_nodes = result.ret_nodes; // keep all the inputs of the function - TF_RETURN_IF_ERROR(TransformNode(call_info, call_info.gcall, func_info.g, grad_result.call_nodes, grad_result.ret_nodes)); + TF_RETURN_IF_ERROR(TransformNode(call_info, call_info.gcall, func_info.g, grad_result.call_nodes, grad_result.ret_nodes,true)); MarkTransformed(grad_result); } MarkCallTransformed(call_info); From 85172e60e58dc35f3b8d9adebb7e0303f85ae26c Mon Sep 17 00:00:00 2001 From: George Vasilakopoulos Date: Wed, 19 Jun 2024 01:52:58 +0000 Subject: [PATCH 50/53] Replace correctly the outputs of SymbolicGradient nodes --- .../optimizers/function_transformation.cc | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/tensorflow/core/grappler/optimizers/function_transformation.cc b/tensorflow/core/grappler/optimizers/function_transformation.cc index f03afdf1aae692..610b757374a523 100644 --- a/tensorflow/core/grappler/optimizers/function_transformation.cc +++ b/tensorflow/core/grappler/optimizers/function_transformation.cc @@ -669,6 +669,8 @@ Status CallRewriter::TransformNode(const CallInfo& info, std::vector& ret_nodes, bool is_gradient_node 
= false) { CHECK_EQ(call->input_size(), f.args.size()); + unsigned int next_return_node = is_gradient_node ? ret_nodes.size() : 0; + call_nodes.resize(f.args.size()); for (unsigned int i = 0; i < f.args.size(); i++) { /* check if call node is already in place, if so, validate and skip */ @@ -743,12 +745,10 @@ Status CallRewriter::TransformNode(const CallInfo& info, *out->add_input() = ret_nodes[i]->name(); } } else { - for (unsigned int i = 0; i < f.rets.size(); i++) { - ReplaceOutput(strings::StrCat(call->name(), ":", i), ret_nodes[i]->name()); + for (unsigned int i = next_return_node; i < f.rets.size(); i++) { + ReplaceOutput(strings::StrCat(call->name(), ":", i - next_return_node), ret_nodes[i]->name()); + if(i == next_return_node)ReplaceOutput(call->name(), ret_nodes[i]->name()); } - if (f.rets.size() == 1) { - ReplaceOutput(call->name(), ret_nodes[0]->name()); - } } // for each call create a control dependency to each return @@ -845,6 +845,13 @@ void CallRewriter::Flush() { graph->mutable_node()->DeleteSubrange(last + 1, graph->node_size() - last - 1); } + + + // for(auto& p : output_map_){ + // printf("%s -> %s\n",p.first.c_str(),p.second.c_str()); + + // } + if (!output_map_.empty()) { for (NodeDef& node : *graph->mutable_node()) { std::vector control_nodes; From 148299d9b11865a8f6bf388827741141449fc594 Mon Sep 17 00:00:00 2001 From: George Vasilakopoulos Date: Sun, 23 Jun 2024 01:07:30 +0000 Subject: [PATCH 51/53] Create gradient graph through python APIs --- .../optimizers/function_transformation.cc | 95 +++++++++++-------- tensorflow/python/framework/function.py | 85 +++++++++++++++-- tensorflow/python/framework/ops.py | 12 +-- tensorflow/python/ops/gradients_impl.py | 5 +- tensorflow/python/ops/gradients_util.py | 8 +- 5 files changed, 146 insertions(+), 59 deletions(-) diff --git a/tensorflow/core/grappler/optimizers/function_transformation.cc b/tensorflow/core/grappler/optimizers/function_transformation.cc index 610b757374a523..1cc92780cef45c 100644 --- a/tensorflow/core/grappler/optimizers/function_transformation.cc +++ b/tensorflow/core/grappler/optimizers/function_transformation.cc @@ -188,6 +188,8 @@ class FunctionInliningContext { void InitializeInlinedFunctions(const GrapplerItem& item) { for (const FunctionDef& func : item.graph.library().function()) { + printf("Func name %s\n",func.signature().name().c_str()); + bool marked_noinline = MarkedNoInline(func); // Don't inline functions marked as noinline if (marked_noinline) { @@ -472,47 +474,25 @@ Status InlineFunctionAndGradient(const FunctionDef& fdef, FuncGradInfo& func_info) { // Get func_def's gradient graph - // const FunctionDef* fgdef = ctx.FindInlinedFunctionAndGradient(fdef.signature().name()); - // if(fgdef == nullptr){ - // std::unique_ptr fbody; - // TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(fdef,AttrSlice(&fdef.attr()),&ctx.FunctionLibrary(), &fbody)); - - // printf("Original graph %s\n\n",SummarizeGraphDef((fbody)->graph->ToGraphDefDebug()).c_str()); - - // fbody = SymbolicGradient(*fbody); - - // printf("Gradient graph %s\n\n",SummarizeGraphDef((fbody)->graph->ToGraphDefDebug()).c_str()); - - // FunctionDef graddef = fbody->record->fdef(); - // printf("Debug string: %s\n",graddef.DebugString().c_str()); - // std::string grad_name = strings::StrCat(fdef.signature().name(), "Grad"); - // graddef.mutable_signature()->set_name(grad_name); - // TF_RETURN_IF_ERROR(ctx.AddFunctionDef(graddef)); - // fgdef = ctx.FindInlinedFunctionAndGradient(fdef.signature().name()); - // } - GraphDef grad_graph; - 
std::unique_ptr fbody; - TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(fdef,AttrSlice(&fdef.attr()),&ctx.FunctionLibrary(), &fbody)); - - printf("Original graph %s\n\n",SummarizeGraphDef((fbody)->graph->ToGraphDefDebug()).c_str()); - - fbody = SymbolicGradient(*fbody); - - printf("Gradient graph %s\n\n",SummarizeGraphDef((fbody)->graph->ToGraphDefDebug()).c_str()); - - fbody->graph->ToGraphDef(&grad_graph); + const FunctionDef* fgdef = ctx.FindInlinedFunctionAndGradient(fdef.signature().name()); + if (fgdef == nullptr) { + return errors::InvalidArgument( + "Invalid argument, gradient of function ", fdef.signature().name(), "can not be found", + "or not marked to be inlined"); + } - // GrapplerFunctionItem item; - const int graph_version = fbody->graph->versions().producer(); - // TF_RETURN_IF_ERROR(MakeGrapplerFunctionItem(*fgdef, func_instantiation_attr, ctx.FunctionLibrary(), graph_version, &item)); + + GrapplerFunctionItem item; + const int graph_version = graph->versions().producer(); + TF_RETURN_IF_ERROR(MakeGrapplerFunctionItem(*fgdef, func_instantiation_attr, ctx.FunctionLibrary(), graph_version, &item)); string prefix = fdef.signature().name(); size_t farg_size = fdef.signature().input_arg_size(); size_t fret_size = fdef.signature().output_arg_size(); - size_t garg_size = fbody->arg_nodes.size();// - farg_size; - size_t gret_size = fbody->ret_nodes.size();// - fret_size; + size_t garg_size = fgdef->signature().input_arg_size();// - farg_size; + size_t gret_size = fgdef->signature().output_arg_size();// - fret_size; CHECK_EQ(farg_size, gret_size - fret_size); CHECK_EQ(garg_size, fret_size + farg_size); @@ -530,10 +510,10 @@ Status InlineFunctionAndGradient(const FunctionDef& fdef, for (int i = 0; i < gret_size; i++) { // const OutputArgInstantiation& output_arg = item.output(i); if(i < fret_size){ - func_info.f.ret_types[i] = fbody->ret_types[i]; - func_info.g.arg_types[farg_size + i] = fbody->ret_types[i]; + func_info.f.ret_types[i] = item.output(i).data_type; + func_info.g.arg_types[farg_size + i] = item.output(i).data_type; } - func_info.g.ret_types[i] = fbody->ret_types[i]; + func_info.g.ret_types[i] = item.output(i).data_type; } // create an inverse map of arg to provide name -> argument number @@ -541,9 +521,9 @@ Status InlineFunctionAndGradient(const FunctionDef& fdef, std::vector input_names; input_names.resize(farg_size); for (int i = 0; i < garg_size; ++i) { - input_map[fbody->arg_nodes[i]->name()] = i; + input_map[item.input(i).node_name] = i; if (i < farg_size) { - input_names[i] = fbody->arg_nodes[i]->name(); + input_names[i] = item.input(i).node_name; } } func_info.f.args.resize(farg_size); @@ -553,7 +533,7 @@ Status InlineFunctionAndGradient(const FunctionDef& fdef, // prefix each node in function graph and place it to the global graph. // the inputs of each node need to be renamed as well to reflect the change. 
- for (NodeDef& n : *grad_graph.mutable_node()) { + for (NodeDef& n : *item.mutable_function_body().mutable_node()) { // If the func body node is func's input argument auto input_it = input_map.find(n.name()); bool is_input = input_it != input_map.end(); @@ -608,10 +588,10 @@ Status InlineFunctionAndGradient(const FunctionDef& fdef, } } - CHECK_EQ(gret_size, fbody->arg_nodes.size()); + CHECK_EQ(gret_size, item.fetch.size()); for (unsigned int i = 0; i < gret_size; i++) { - string output_port = AddPrefixToNodeName(fbody->ret_nodes[i]->name(), prefix); + string output_port = AddPrefixToNodeName(item.output(i).node_name, prefix); if (i < fret_size) { func_info.f.rets[i] = output_port; } @@ -918,7 +898,38 @@ Status FunctionTransformation::Optimize(Cluster* cluster, const GrapplerItem& it call_rewriter.Flush(); } call_rewriter.Flush(); + + printf("After finalizing:\n %s\n", SummarizeGraphDef(*output).c_str()); + + // for (NodeDef& node : *output->mutable_node()){ + + // if(node.op() != kGradientOp)continue; + // NameAttrList func; + // TF_RETURN_IF_ERROR(GetNodeAttr(AttrSlice(node), kFuncAttrName, &func)); + // gradient::Creator creator; + // TF_RETURN_IF_ERROR(gradient::GetOpGradientCreator(func.name(), &creator)); + // if (creator == nullptr) { + // return absl::InvalidArgumentError( + // absl::StrCat("No gradient is defined for ", func.name())); + // } + // FunctionDef grad_fdef; + + // std::unique_ptr* fbody; + // TF_RETURN_IF_ERROR(creator(AttrSlice(&func.attr()), &grad_fdef)); + // TF_RETURN_IF_ERROR(FunctionDefToBodyHelper( + // grad_fdef, AttrSlice(&func.attr()), &ctx.FunctionLibrary(), fbody)); + + // printf("Gradient of of %s:\n%s\n\n",func.name().c_str(),SummarizeGraphDef((*fbody)->graph->ToGraphDefDebug()).c_str()); + + + // } + + + + + + *output->mutable_versions() = item.graph.versions(); // Function Library should be pruned of unreachable function definitions diff --git a/tensorflow/python/framework/function.py b/tensorflow/python/framework/function.py index 8fb6efa0bab7bd..8dc2629c738e11 100644 --- a/tensorflow/python/framework/function.py +++ b/tensorflow/python/framework/function.py @@ -153,6 +153,7 @@ def __init__(self, *input_types, **kwargs): self._func_name = kwargs.pop("func_name", None) self._grad_func = kwargs.pop("grad_func", None) self._python_grad_func = kwargs.pop("python_grad_func", None) + self._create_grad_func = kwargs.pop("create_grad_func", False) self._out_names = kwargs.pop("out_names", None) self._extra_kwargs = kwargs @@ -197,6 +198,8 @@ def __call__(self, func): self._func_name, self._grad_func, self._python_grad_func, + self._create_grad_func, + is_gradient=False, out_names=self._out_names, **self._extra_kwargs) @@ -320,6 +323,8 @@ def __init__(self, func_name=None, grad_func=None, python_grad_func=None, + create_grad_func=False, + is_gradient=False, out_names=None, shape_func=None, capture_by_value=False, @@ -361,6 +366,8 @@ def __init__(self, self._func_name = func_name self._grad_func = grad_func self._python_grad_func = python_grad_func + self._create_grad_func = create_grad_func + self._is_gradient = is_gradient self._out_names = out_names self._shape_func = shape_func self._capture_by_value = capture_by_value @@ -387,12 +394,31 @@ def __init__(self, # is disabled the whole _definition is available and this is simply # another reference to _definition.signature self._op_def = None - + assert isinstance(input_types, (list, tuple)) self._arg_types = input_types self._arg_names = [argnames[i] if i < len(argnames) else ("arg%d" % i) for i in 
                        range(len(input_types))]
+
+    self._args = list(zip(self._arg_names,self._arg_types))
+    if self._create_grad_func:
+      grad_func_name = self._func_name #+ "Grad"
+      out_names = self._out_names.copy()
+      for (argname, argtype) in self._args:
+        out_names.append("d" + argname)
+      # Todo: check if we need to copy all the args so that they don't get passed by reference
+      self._grad_func = _DefinedFunction(func=func,
+                                         argnames=argnames,
+                                         input_types=input_types,
+                                         func_name=grad_func_name,
+                                         grad_func=None,
+                                         python_grad_func=None,
+                                         create_grad_func=False,
+                                         is_gradient=True,
+                                         out_names=out_names,
+                                         **kwargs)
+
   @property
   def name(self):
     """Function name."""
@@ -471,7 +497,8 @@ def _create_definition_if_needed(self):

   def _create_definition_if_needed_impl(self):
     """This is not what you want, see _create_definition_if_needed."""
-    if self._definition is not None or self._c_func is not None:
+    if self._definition is not None or self._c_func is not None \
+        or (self._is_gradient and not ops.get_default_graph()._is_function(self._func_name)):
       return

     # Copy variable collections (by reference) from the parent graph such that
@@ -489,18 +516,25 @@ def _create_definition_if_needed_impl(self):
         self._func,
         self._arg_names,
         self._arg_types,
+        self._out_names,
         self._func_name,
+        self._is_gradient,
         self._capture_by_value,
         self._caller_device,
         collections_ref=collections_ref,
         allowlisted_stateful_ops=self._allowlisted_stateful_ops,
-        capture_resource_var_by_value=self._capture_resource_var_by_value)
+        capture_resource_var_by_value=self._capture_resource_var_by_value,
+        functions=parent_graph._functions
+        )

     self._extra_inputs = temp_graph.extra_inputs
     # pylint: disable=protected-access
     self._sub_functions = temp_graph._functions
     # pylint: enable=protected-access

+    if self._is_gradient and self._func_name:
+      self._func_name += "Grad"
+
     # Extra kwargs are treated as attrs on the function def.
     if self._func_name:
       base_func_name = self._func_name
@@ -1001,7 +1035,9 @@ def _add_op_and_parents(self, op: ops.Operation):
 def func_graph_from_py_func(func,
                             arg_names,
                             arg_types,
+                            out_names,
                             name=None,
+                            is_gradient=False,
                             capture_by_value=False,
                             device=None,
                             colocation_stack=None,
@@ -1009,7 +1045,8 @@ def func_graph_from_py_func(func,
                             collections_ref=None,
                             arg_shapes=None,
                             allowlisted_stateful_ops=None,
-                            capture_resource_var_by_value=True):
+                            capture_resource_var_by_value=True,
+                            functions=None):
   """Returns a _FuncGraph generated from `func`.

   Args:
@@ -1062,7 +1099,25 @@ def func_graph_from_py_func(func,
       func_graph.inputs.append(argholder)
     # Call func and gather the output tensors.
     with vs.variable_scope("", custom_getter=func_graph.getvar):
-      outputs = func(*func_graph.inputs)
+      gradient_out_types = []
+      if is_gradient:
+        name = name + "Grad"
+        outputs = [func(*func_graph.inputs)]
+        dinputs = []
+        for (out, name) in list(zip(outputs, out_names)):
+          argholder = array_ops.placeholder(out.op.node_def.attr["T"].type, name="d"+name)
+          dinputs.append(argholder)
+          gradient_out_types.append(out.op.node_def.attr["T"].type)
+        for argtype in arg_types:
+          gradient_out_types.append(argtype)
+        from tensorflow.python.ops import gradients_impl
+        doutputs = gradients_impl.gradients(outputs, func_graph.inputs, dinputs, functions = functions)
+        if not isinstance(doutputs, list):
+          doutputs = [doutputs]
+        outputs.extend(doutputs)
+        func_graph.inputs.extend(dinputs)
+      else:
+        outputs = func(*func_graph.inputs)

     # There is no way of distinguishing between a function not returning
     # anything and a function returning None in Python.
@@ -1078,10 +1133,24 @@ def func_graph_from_py_func(func,
     # If func only returned one value, make it a tuple.
     if not isinstance(outputs, (list, tuple)):
       outputs = (outputs,)
-    if any(_ is None for _ in outputs):
-      raise ValueError(f"Function {name} can not return None.")
+    # if any(_ is None for _ in outputs):
+    #   raise ValueError(f"Function {name} can not return None.")
     # Ensures each output is a Tensor in the function graph.
-    outputs = [ops.convert_to_tensor(t) for t in outputs]
+    if is_gradient:
+      tmp_out = []
+      for out, out_type in zip(outputs, gradient_out_types):
+        if out is not None:
+          tmp_out.append(ops.convert_to_tensor(out))
+        else:
+          if out_type.is_bool:
+            tmp_out.append(ops.convert_to_tensor(False))
+          elif out_type.is_floating:
+            tmp_out.append(ops.convert_to_tensor(0.0))
+          else:
+            tmp_out.append(ops.convert_to_tensor(0))
+      outputs = tmp_out
+    else:
+      outputs = [ops.convert_to_tensor(t) for t in outputs]
     outputs = [func_graph.capture(t) if t.graph is not func_graph else t
                for t in outputs]
     func_graph.outputs = outputs
diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py
index b863ee2bd878bb..f300fd1caad572 100644
--- a/tensorflow/python/framework/ops.py
+++ b/tensorflow/python/framework/ops.py
@@ -2544,12 +2544,12 @@ def _add_function(self, function) -> None:
     # pylint: disable=protected-access
     with self._c_graph.get() as c_graph:
       with function._c_func.get() as func:
-        if getattr(function, "_grad_func", None):
-          # For deprecated _DefinedFunction.
-          with function._grad_func._c_func.get() as gradient:
-            pywrap_tf_session.TF_GraphCopyFunction(c_graph, func, gradient)
-        else:
-          pywrap_tf_session.TF_GraphCopyFunction(c_graph, func, None)
+        # if getattr(function, "_grad_func", None):
+        #   # For deprecated _DefinedFunction.
+        #   with function._grad_func._c_func.get() as gradient:
+        #     pywrap_tf_session.TF_GraphCopyFunction(c_graph, func, gradient)
+        # else:
+        pywrap_tf_session.TF_GraphCopyFunction(c_graph, func, None)
     # pylint: enable=protected-access
     self._functions[compat.as_str(name)] = function
diff --git a/tensorflow/python/ops/gradients_impl.py b/tensorflow/python/ops/gradients_impl.py
index 14e316a6433f6e..187c71f8a7436b 100644
--- a/tensorflow/python/ops/gradients_impl.py
+++ b/tensorflow/python/ops/gradients_impl.py
@@ -61,7 +61,8 @@ def gradients(ys,
               gate_gradients=False,
               aggregation_method=None,
               stop_gradients=None,
-              unconnected_gradients=UnconnectedGradients.NONE):
+              unconnected_gradients=UnconnectedGradients.NONE,
+              functions = None):
   """Constructs symbolic derivatives of sum of `ys` w.r.t. x in `xs`.

   `ys` and `xs` are each a `Tensor` or a list of tensors.
   `grad_ys`
@@ -181,7 +182,7 @@ def gradients(ys,
   return gradients_util._GradientsHelper(
       ys, xs, grad_ys, name, colocate_gradients_with_ops,
       gate_gradients, aggregation_method, stop_gradients,
-      unconnected_gradients)
+      unconnected_gradients, functions)
   # pylint: enable=protected-access
diff --git a/tensorflow/python/ops/gradients_util.py b/tensorflow/python/ops/gradients_util.py
index fa568ea706cf36..e4647c5adc1776 100644
--- a/tensorflow/python/ops/gradients_util.py
+++ b/tensorflow/python/ops/gradients_util.py
@@ -512,6 +512,7 @@ def _GradientsHelper(ys,
                      aggregation_method=None,
                      stop_gradients=None,
                      unconnected_gradients=UnconnectedGradients.NONE,
+                     functions = None,
                      src_graph=None):
   """Implementation of gradients()."""
   if context.executing_eagerly():
@@ -536,7 +537,7 @@ def _GradientsHelper(ys,
     flat_grads = _GradientsHelper(flat_ys, flat_xs, flat_grad_ys, name,
                                   colocate_gradients_with_ops, gate_gradients,
                                   aggregation_method, stop_gradients,
-                                  unconnected_gradients, src_graph)
+                                  unconnected_gradients, src_graph = src_graph)
     return composite_tensor_gradient.replace_flat_tensors_for_gradients(
         xs, flat_grads)

@@ -637,6 +638,9 @@ def _GradientsHelper(ys,
         is_partitioned_call = _IsPartitionedCall(op)
         # pylint: disable=protected-access
         is_func_call = src_graph._is_function(op.type) or is_partitioned_call
+        if not is_func_call and functions is not None:
+          is_func_call = op.type in functions
+
         # pylint: enable=protected-access
         has_out_grads = any(
             isinstance(g, tensor_lib.Tensor) or g for g in out_grads
@@ -665,6 +669,8 @@ def _GradientsHelper(ys,
               break
           else:
             func_call = src_graph._get_function(op.type)  # pylint: disable=protected-access
+            if func_call is None and functions is not None:
+              func_call = functions.get(op.type,None)
           # Note that __defun is not set if the graph is
           # imported. If it's set, we prefer to access the original
           # defun.
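The `functions` argument threaded through `gradients()` and `_GradientsHelper()` above lets backprop treat an op as a function call even when its type is not registered in `src_graph`, by falling back to a caller-supplied name-to-function map. A minimal usage sketch follows; it assumes this patch series is applied, and the `Square` defun plus the explicit `g._functions` argument are illustrative assumptions rather than part of the series.

    # Illustrative sketch only; requires the `functions=` keyword added above
    # and TF1-style graph construction.
    import tensorflow.compat.v1 as tf
    from tensorflow.python.framework import function
    from tensorflow.python.ops import gradients_impl

    @function.Defun(tf.float32, func_name="Square")
    def Square(x):
      return x * x

    g = tf.Graph()
    with g.as_default():
      x = tf.placeholder(tf.float32, name="x")
      y = Square(x)
      # Pass the graph's function map so ops whose type names a defined
      # function are dispatched through the function-call gradient path.
      (dx,) = gradients_impl.gradients([y], [x], functions=g._functions)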
From 1e4e9d9278c42c41fd429a77a43777ce2073341e Mon Sep 17 00:00:00 2001
From: George Vasilakopoulos
Date: Sun, 23 Jun 2024 18:16:46 +0000
Subject: [PATCH 52/53] debugging pending_count_

---
 tensorflow/core/common_runtime/graph_constructor.cc | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/tensorflow/core/common_runtime/graph_constructor.cc b/tensorflow/core/common_runtime/graph_constructor.cc
index 1d2a7707c1cee3..49fd0f9c032bdd 100644
--- a/tensorflow/core/common_runtime/graph_constructor.cc
+++ b/tensorflow/core/common_runtime/graph_constructor.cc
@@ -839,7 +839,7 @@ Status GraphConstructor::InitFromEdges() {
         }
       }
       if (has_loop_back_edge) {
-        pending_count = num_control_edges + 1;
+        pending_count = std::min(num_control_edges + 1, node_def.input_size());
       }
     } else if (IsReturningNode(node_def)) {
       int num_control_edges = 0;
@@ -849,7 +849,7 @@ Status GraphConstructor::InitFromEdges() {
          num_control_edges++;
        }
      }
-      pending_count = num_control_edges + 1;
+      pending_count = std::min(num_control_edges + 1, node_def.input_size());
     }
     for (int i = 0; i < node_def.input_size(); ++i) {
       StringPiece input_name = node_def.input(i);

From 09615b79d64b709e3111d43b9688c7daccb957d4 Mon Sep 17 00:00:00 2001
From: George Vasilakopoulos
Date: Sat, 29 Jun 2024 16:01:07 +0000
Subject: [PATCH 53/53] fixed type mismatch when func returns multiple outputs

---
 tensorflow/python/framework/function.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/tensorflow/python/framework/function.py b/tensorflow/python/framework/function.py
index 8dc2629c738e11..68e4fc94c00673 100644
--- a/tensorflow/python/framework/function.py
+++ b/tensorflow/python/framework/function.py
@@ -1102,7 +1102,9 @@ def func_graph_from_py_func(func,
       gradient_out_types = []
       if is_gradient:
         name = name + "Grad"
-        outputs = [func(*func_graph.inputs)]
+        outputs = func(*func_graph.inputs)
+        if not isinstance(outputs,list):
+          outputs = [outputs]
         dinputs = []
         for (out, name) in list(zip(outputs, out_names)):
          argholder = array_ops.placeholder(out.op.node_def.attr["T"].type, name="d"+name)
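The fix in the final patch matters because the gradient path zips the forward outputs with `out_names` to build one `d<name>` placeholder per output; wrapping an already-list-valued return in another list paired the whole list with the first name, producing the type mismatch the commit message describes. A small standalone sketch of the corrected normalization, with illustrative values only:

    # Illustrative standalone sketch of the output normalization in patch 53/53.
    def normalize_outputs(raw):
      # Wrap only when the wrapped function returned a single tensor; a list of
      # outputs is kept as-is so it zips element-wise with out_names.
      return raw if isinstance(raw, list) else [raw]

    assert normalize_outputs("y0") == ["y0"]                 # single output wrapped
    assert normalize_outputs(["y0", "y1"]) == ["y0", "y1"]   # multi-output preserved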