diff --git a/configure b/configure index 87ef6e99b..22c36c9bf 100755 --- a/configure +++ b/configure @@ -63,6 +63,7 @@ if is_windows; then TF_NEED_HDFS=0 TF_NEED_JEMALLOC=0 TF_NEED_OPENCL=0 + TF_NEED_MPI=0 fi while [ "$TF_NEED_JEMALLOC" == "" ]; do @@ -112,6 +113,68 @@ else sed -i -e "s/WITH_GCP_SUPPORT = True/WITH_GCP_SUPPORT = False/" tensorflow/core/platform/default/build_config.bzl fi +while [ "$TF_NEED_MPI" == "" ]; do + read -p "Do you wish to build TensorFlow with "\ +"Message Passing Interface (MPI) support? [y/N] " INPUT + case $INPUT in + [Yy]* ) echo "MPI support will be enabled for "\ +"TensorFlow"; TF_NEED_MPI=1;; + [Nn]* ) echo "No MPI support will be enabled for "\ +"TensorFlow"; TF_NEED_MPI=0;; + "" ) echo "No Hadoop File System support will be enabled for "\ +"TensorFlow"; TF_NEED_MPI=0;; + * ) echo "Invalid selection: " $INPUT;; + esac +done + +while true; do + if [ "$TF_NEED_MPI" == "0" ]; then + break; + fi + + fromuser="" + if [ -z "$MPI_PATH" ]; then + default_mpi_path=$(dirname $(dirname $(which mpirun)) || dirname $(dirname $(which mpiexec)) || true) + read -p "Please specify the location of MPI. [Default is $default_mpi_path]: " MPI_PATH + fromuser="1" + if [ -z "$MPI_PATH" ]; then + MPI_PATH=$default_mpi_path + fi + fi + if [ -e "$MPI_PATH/include" -a -e "$MPI_PATH/lib" ]; then + break + fi + echo "Invalid MPI path. ${MPI_PATH}/include or ${MPI_PATH}/lib cannot be found" 1>&2 + if [ -z "$fromuser" ]; then + exit 1 + fi + MPI_PATH="" + # Retry +done + +if [ "$TF_NEED_MPI" == "1" ]; then + # Symlink necessary parts of MPI into the third party directory. + ln -sf "${MPI_PATH}/include/mpi.h" third_party/mpi/mpi.h + + if [ -e "${MPI_PATH}/include/mpi_portable_platform.h" ]; then + ln -sf "${MPI_PATH}/include/mpi_portable_platform.h" third_party/mpi/mpi_portable_platform.h + fi + + if [ -e "${MPI_PATH}/lib/libmpi.so" ]; then + ln -sf "${MPI_PATH}/lib/libmpi.so" third_party/mpi/libmpi.so + fi + + if [ -e "${MPI_PATH}/lib/libmpi.dylib" ]; then + ln -sf "${MPI_PATH}/lib/libmpi.dylib" third_party/mpi/libmpi.dylib + fi + + # Update Bazel build configuration. + sed -i -e "s/WITH_MPI_SUPPORT = False/WITH_MPI_SUPPORT = True/" tensorflow/contrib/BUILD +else + # Update Bazel build configuration. + sed -i -e "s/WITH_MPI_SUPPORT = True/WITH_MPI_SUPPORT = False/" tensorflow/contrib/BUILD +fi + while [ "$TF_NEED_HDFS" == "" ]; do read -p "Do you wish to build TensorFlow with "\ "Hadoop File System support? [y/N] " INPUT diff --git a/tensorflow/contrib/BUILD b/tensorflow/contrib/BUILD index 680053ae1..beadaf0ac 100644 --- a/tensorflow/contrib/BUILD +++ b/tensorflow/contrib/BUILD @@ -1,6 +1,9 @@ # Description: # contains parts of TensorFlow that are experimental or unstable and which are not supported. +# The configure script may change this to True. +WITH_MPI_SUPPORT = True + licenses(["notice"]) # Apache 2.0 exports_files(["LICENSE"]) @@ -56,7 +59,7 @@ py_library( "//tensorflow/contrib/tfprof", "//tensorflow/contrib/training:training_py", "//tensorflow/contrib/util:util_py", - ], + ] + ["//tensorflow/contrib/mpi:mpi_ops_py"] if WITH_MPI_SUPPORT else [], ) cc_library( diff --git a/tensorflow/contrib/mpi/BUILD b/tensorflow/contrib/mpi/BUILD new file mode 100644 index 000000000..b75773327 --- /dev/null +++ b/tensorflow/contrib/mpi/BUILD @@ -0,0 +1,66 @@ +# Ops that communicate with other processes via MPI. + +licenses(["notice"]) # Apache 2.0 + +package(default_visibility = ["//tensorflow:__subpackages__"]) + +load("//tensorflow:tensorflow.bzl", "tf_gen_op_wrapper_py") +load("//tensorflow:tensorflow.bzl", "tf_copts") +load("//tensorflow:tensorflow.bzl", "tf_custom_op_library") +load("//tensorflow:tensorflow.bzl", "tf_py_test") +load("//tensorflow/core:platform/default/build_config.bzl", "tf_proto_library") + +tf_custom_op_library( + name = "mpi.so", + srcs = ["mpi_ops.cc", "ring.cc", "ring.h"], + gpu_srcs = ["ring.cu.cc", "ring.h"], + deps = [ + "//third_party/mpi:mpi", + ":mpi_message_proto_cc", + ], +) + +tf_py_test( + name = "mpi_ops_test", + srcs = ["mpi_ops_test.py"], + additional_deps = [ + "//tensorflow:tensorflow_py", + "//tensorflow/python:platform", + ], + data = [ + ":mpi.so", + ], + tags = ["manual"], +) + +py_library( + name = "mpi_ops_py", + srcs = [ + "__init__.py", + "mpi_ops.py", + ], + data = [ + ":mpi.so", + ], + srcs_version = "PY2AND3", + visibility = ["//visibility:public"], +) + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) + +tf_proto_library( + name = "mpi_message_proto", + srcs = ["mpi_message.proto"], + cc_api_version = 2, + visibility = ["//visibility:public"], +) diff --git a/tensorflow/contrib/mpi/README.md b/tensorflow/contrib/mpi/README.md new file mode 100644 index 000000000..c5e1a8c37 --- /dev/null +++ b/tensorflow/contrib/mpi/README.md @@ -0,0 +1,5 @@ +# MPI TensorFlow integration + +Tensorflow MPI integration allows communicating between different TensorFlow +processes using MPI. This enables training across multiple nodes and GPUs +using high-speed interconnects. diff --git a/tensorflow/contrib/mpi/__init__.py b/tensorflow/contrib/mpi/__init__.py new file mode 100644 index 000000000..79a562367 --- /dev/null +++ b/tensorflow/contrib/mpi/__init__.py @@ -0,0 +1,281 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# pylint: disable=g-short-docstring-punctuation +"""## Communicating Between Processes with MPI + +TensorFlow natively provides inter-device communication through send and +receive ops and inter-node communication through Distributed TensorFlow, based +on the same send and receive abstractions. On HPC clusters where Infiniband or +other high-speed node interconnects are available, these can end up being +insufficient for synchronous data-parallel training (without asynchronous +gradient descent). This module implements a variety of MPI ops which can take +advantage of hardware-specific MPI libraries for efficient communication. + +In order to use this module, TensorFlow must be built with an MPI library, +which can be provided to the `./configure` script at build time. As a user of +TensorFlow, you will need to build TensorFlow yourself to select the MPI +library to use; to do so, follow the [instructions for building TensorFlow from +source](https://www.tensorflow.org/get_started/os_setup#installing_from_sources). + +### Utility Ops + +In addition to reductions and gathers, this module provides utility operations +for detecting the running MPI configuration. + +Example: + +```python +from tensorflow.contrib import mpi + +# Use `mpi.Session` instead of `tf.Session` +with mpi.Session() as session: + rank = session.run(mpi.rank()) + print("My MPI Rank:", rank) + + if rank == 0: + print("MPI Size:", session.run(mpi.size())) +``` + +@@rank +@@size + +### Ring Allreduce and Allgather + +When summing or averaging tensors across many processes, communication can +easily become a bottleneck. A naive implementation will send all the tensor +values to the same process, perform the reduction, and then broadcast the +values back to all other processes, effectively creating a synchronous +parameter server in one process. However, the process responsible for +performing the reduction will have to receive and send a massive amount of data +which scales with the number of processes *and* the number of parameters in the +model. + +Instead of centralizing the reduction and having one primary reducer, we can +implement a distributed allreduce or allgather. A bandwidth-optimal allreduce +will end up sending 2(N - 1) values for every value in the input tensor, +and can be implemented with a ring allreduce [1]. (Intuitively, a linear reduce +requires at least (N - 1) sends between the different nodes, and a broadcast of +the result also requires (N - 1) sends, for a total of 2 (N - 1); these two +steps cannot be combined in a clever way to reduce the number of required +sends.) This module implements bandwidth-optimal ring allreduce and ring +allgather operations using MPI; by choosing a hardware-appropriate MPI +implementation (such as OpenMPI with CUDA-IPC support), you can train large +models with synchronous gradient descent with minimal communication overhead. + +In addition to the `allreduce` and `allgather` functions, a convenience +`DistributedOptimizer` wrapper is provided to simplify using these functions +for reducing model gradients. + +Example: + +```python +import tensorflow as tf +from tensorflow.contrib import mpi + +# Construct a simple linear regression model to optimize +W = tf.get_variable("W", shape=[20, 1], dtype=tf.float32) +B = tf.get_variable("B", shape=[1, 1], dtype=tf.float32) +inputs = tf.placeholder("Inputs", shape=[None, 20]) +outputs = tf.placeholder("Outputs", shape=[None, 1]) +loss = tf.nn.l2_loss(tf.matmul(inputs, W) + B - outputs) + +# Training using MPI allreduce with DistributedOptimizer +optimizer = mpi.DistributedOptimizer(tf.train.AdamOptimizer()) +train = optimizer.minimize(loss) + +# Average loss over all ranks, for printing. +# Do not pass this to an optimizer! +avg_loss = mpi.allreduce(loss) + +# On different ranks, feed different input data. +with mpi.Session() as session: + rank = session.run(mpi.rank()) + batch_inputs, batch_outputs = construct_batch_for_rank(rank) + feed_dict = {inputs: batch_inputs, outputs: batch_outputs} + _, l = session.run([train, avg_loss], feed_dict=feed_dict) + print("Average Loss:", l) +``` + +[1] Patarasuk, Pitch and Yuan, Xin. "Bandwidth Optimal All-reduce Algorithms +for Clusters of Workstations". + +@@Session +@@DistributedOptimizer +@@allreduce +@@allgather +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow as tf + +from tensorflow.contrib.mpi.mpi_ops import size +from tensorflow.contrib.mpi.mpi_ops import rank +from tensorflow.contrib.mpi.mpi_ops import local_rank +from tensorflow.contrib.mpi.mpi_ops import allgather +from tensorflow.contrib.mpi.mpi_ops import _allreduce +from tensorflow.contrib.mpi.mpi_ops import init + + +def allreduce(tensor, average=True): + """Perform an MPI allreduce on a tf.Tensor or tf.IndexedSlices. + + Arguments: + tensor: tf.Tensor, tf.Variable, or tf.IndexedSlices to reduce. + The shape of the input must be identical across all ranks. + average: If True, computes the average over all ranks. + Otherwise, computes the sum over all ranks. + + This function performs a bandwidth-optimal ring allreduce on the input + tensor. If the input is an tf.IndexedSlices, the function instead does an + allgather on the values and the indices, effectively doing an allreduce on + the represented tensor. + """ + if isinstance(tensor, tf.IndexedSlices): + # For IndexedSlices, do two allgathers intead of an allreduce. + mpi_size = tf.cast(size(), tensor.values.dtype) + values = allgather(tensor.values) + indices = allgather(tensor.indices) + + # To make this operation into an average, divide all gathered values by + # the MPI size. + new_values = tf.div(values, mpi_size) if average else values + return tf.IndexedSlices(new_values, indices, + dense_shape=tensor.dense_shape) + else: + mpi_size = tf.cast(size(), tensor.dtype) + summed_tensor = _allreduce(tensor) + new_tensor = (tf.div(summed_tensor, mpi_size) + if average else summed_tensor) + return new_tensor + + +class DistributedOptimizer(tf.train.Optimizer): + """An optimizer that wraps another tf.Optimizer, using an MPI allreduce to + average gradient values before applying gradients to model weights.""" + + def __init__(self, optimizer, name=None, use_locking=False): + """Construct a new DistributedOptimizer, which uses another optimizer + under the hood for computing single-process gradient values and + applying gradient updates after the gradient values have been averaged + across all the MPI ranks. + + Args: + optimizer: + Optimizer to use for computing gradients and applying updates. + name: + Optional name prefix for the operations created when applying + gradients. Defaults to "Distributed" followed by the provided + optimizer type. + use_locking: + Whether to use locking when updating variables. + See Optimizer.__init__ for more info. + """ + if name is None: + name = "Distributed{}".format(type(optimizer).__name__) + + self._optimizer = optimizer + super(DistributedOptimizer, self).__init__( + name=name, use_locking=use_locking) + + def compute_gradients(self, *args, **kwargs): + """Compute gradients of all trainable variables. + + See Optimizer.compute_gradients() for more info. + + In DistributedOptimizer, compute_gradients() is overriden to also + allreduce the gradients before returning them. + """ + gradients = (super(DistributedOptimizer, self) + .compute_gradients(*args, **kwargs)) + return [(allreduce(gradient), var) for (gradient, var) in gradients] + + def _apply_dense(self, *args, **kwargs): + """Calls this same method on the underlying optimizer.""" + return self._optimizer._apply_dense(*args, **kwargs) + + def _apply_sparse(self, *args, **kwargs): + """Calls this same method on the underlying optimizer.""" + return self._optimizer._apply_sparse(*args, **kwargs) + + def _prepare(self, *args, **kwargs): + """Calls this same method on the underlying optimizer.""" + return self._optimizer._prepare(*args, **kwargs) + + def _create_slots(self, *args, **kwargs): + """Calls this same method on the underlying optimizer.""" + return self._optimizer._create_slots(*args, **kwargs) + + def _valid_dtypes(self, *args, **kwargs): + """Calls this same method on the underlying optimizer.""" + return self._optimizer._valid_dtypes(*args, **kwargs) + + def _finish(self, *args, **kwargs): + """Calls this same method on the underlying optimizer.""" + return self._optimizer._finish(*args, **kwargs) + + +class Session(tf.Session): + """A class for running TensorFlow operations, with copies of the same graph + running distributed across different MPI nodes. + + The primary difference between `tf.Session` and `tf.contrib.mpi.Session` is + that the MPI `Session` ensures that the `Session` options are correct for + use with `tf.contrib.mpi`, and initializes MPI immediately upon the start + of the session. + """ + + def __init__(self, gpu=None, target='', graph=None, config=None): + """Creates a new TensorFlow MPI session. + + Unlike a normal `tf.Session`, an MPI Session may only use a single GPU, + which must be specified in advance before the session is initialized. + In addition, it only uses a single graph evaluation thread, and + initializes MPI immediately upon starting. + + If no `graph` argument is specified when constructing the session, + the default graph will be launched in the session. If you are + using more than one graph (created with `tf.Graph()` in the same + process, you will have to use different sessions for each graph, + but each graph can be used in multiple sessions. In this case, it + is often clearer to pass the graph to be launched explicitly to + the session constructor. + + Args: + gpu: (Optional.) The GPU index to use, or None for CPU only MPI. + graph: (Optional.) The `Graph` to be launched (described above). + config: (Optional.) A `ConfigProto` protocol buffer with configuration + options for the session. + """ + if config is None: + config = tf.ConfigProto() + config.inter_op_parallelism_threads = 1 + elif config.inter_op_parallelism_threads != 1: + raise ValueError( + "inter_op_parallelism_threads must be 1 for MPI") + else: + config.inter_op_parallelism_threads = 1 + + if gpu is None: + config.gpu_options.visible_device_list = "" + else: + config.gpu_options.visible_device_list = str(gpu) + + super(Session, self).__init__(target, graph, config=config) + + # Initialize MPI on the relevant device. + self.run(init()) diff --git a/tensorflow/contrib/mpi/mpi_message.proto b/tensorflow/contrib/mpi/mpi_message.proto new file mode 100644 index 000000000..7c7866481 --- /dev/null +++ b/tensorflow/contrib/mpi/mpi_message.proto @@ -0,0 +1,66 @@ +syntax = "proto3"; + +package tensorflow.contrib.mpi; + +// We would like to just use DataType here, but since this +// is a contrib package, linking directly to TensorFlow protos seems to be +// impossible. Doing so compiles, but fails with a cryptic error at runtime +// about a pointer that was passed to free() but not created by malloc(). +// +// Since using the tensorflow/core protos seems to cause issues, we use our own, +// which also has the benefit of supporting only the data types we want to support. +enum MPIDataType { + TF_MPI_FLOAT32 = 0; + TF_MPI_INT32 = 1; +}; + +// An MPIRequest is a message sent from a rank greater than zero to the +// coordinator (rank zero), informing the coordinator of an operation that +// the rank wants to do and the tensor that it wants to apply the operation to. +message MPIRequest { + enum RequestType { + ALLREDUCE = 0; + ALLGATHER = 1; + } + + // The request rank is necessary to create a consistent ordering of results, + // for example in the allgather where the order of outputs should be sorted + // by rank. + int32 request_rank = 1; + RequestType request_type = 2; + MPIDataType tensor_type = 3; + string tensor_name = 4; + + // We use a repeated integer instead of a TensorShapeProto because linking directly + // to TensorFlow protos causes issues. See the comment for MPIDataType. + repeated int64 tensor_shape = 5; +}; + +// An MPIResponse is a message sent from the coordinator (rank zero) to a rank +// greater than zero, informing the rank of an operation should be performed +// now. If the operation requested would result in an error (for example, due +// to a type or shape mismatch), then the MPIResponse can contain an error and +// an error message instead. Finally, an MPIResponse can be a DONE message (if +// there are no more tensors to reduce on this tick of the background loop) or +// SHUTDOWN if all MPI processes should shut down. +message MPIResponse { + enum ResponseType { + ALLREDUCE = 0; + ALLGATHER = 1; + ERROR = 2; + DONE = 3; + SHUTDOWN = 4; + } + + // Empty if the type is DONE or SHUTDOWN. + ResponseType response_type = 1; + string tensor_name = 2; + + // Empty unless response_type is ERROR. + string error_message = 3; + + // Empty unless response_type is ALLGATHER. + // These tensor sizes are the dimension zero sizes of all the input matrices, + // indexed by the rank. + repeated int64 tensor_sizes = 4; +}; diff --git a/tensorflow/contrib/mpi/mpi_ops.cc b/tensorflow/contrib/mpi/mpi_ops.cc new file mode 100644 index 000000000..cc249821e --- /dev/null +++ b/tensorflow/contrib/mpi/mpi_ops.cc @@ -0,0 +1,1147 @@ +// Copyright 2016 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= + +#include +#include +#include + +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/shape_inference.h" + +#define EIGEN_USE_THREADS + +#if GOOGLE_CUDA +#include "tensorflow/stream_executor/stream.h" +#include +#endif + +#include "tensorflow/stream_executor/lib/statusor.h" + + +#define OMPI_SKIP_MPICXX +#include "third_party/mpi/mpi.h" +#include "tensorflow/contrib/mpi/ring.h" +#include "tensorflow/contrib/mpi/mpi_message.pb.h" + +/* + * MPI Allreduce and Allgather Ops for TensorFlow. + * + * TensorFlow natively provides inter-device communication through send and + * receive ops and inter-node communication through Distributed TensorFlow, + * based on the same send and receive abstractions. These end up being + * insufficient for synchronous data-parallel training on HPC clusters where + * Infiniband or other high-speed interconnects are available. This module + * implements MPI ops for allgather and allreduce, which do bandwidth-optimal + * gathers and reductions and can take advantage of hardware-optimized + * communication libraries through the MPI implementation. + * + * The primary logic of the allreduce and allgather are in RingAllgather() and + * RingAllreduce(). The background thread which facilitates MPI operations is + * run in BackgroundThreadLoop(). The provided MPI ops are: + * – MPIInit: + * Initialize MPI on a given device (CPU or GPU). + * Should only be run on a single device in every process. + * – MPISize: + * Get the number of MPI processes in the global communicator. + * – MPIRank: + * Get the rank of the current MPI process in the global communicator. + * – MPILocalRank: + * Get the local rank of the current MPI process within its node. + * – MPIAllreduce: + * Perform an allreduce on a Tensor, returning the sum + * across all MPI processes in the global communicator. + * – MPIAllgather: + * Perform an allgather on a Tensor, returning the concatenation of + * the tensor on the first dimension across all MPI processes in the + * global communicator. + * + */ + +template +using StatusOr = perftools::gputools::port::StatusOr; + +using CPUDevice = Eigen::ThreadPoolDevice; +using GPUDevice = Eigen::GpuDevice; + +namespace tensorflow { +namespace contrib { +namespace mpi { + +// Make sure template specializations are generated in the ring.cu.cc and the +// ring.cc file, not in this file. +extern template Status RingAllreduce(OpKernelContext*, Tensor&, Tensor*); +extern template Status RingAllreduce(OpKernelContext*, Tensor&, Tensor*); +extern template Status RingAllgather( + OpKernelContext*, Tensor&, Tensor*, std::vector&); +extern template Status RingAllgather( + OpKernelContext*, Tensor&, Tensor*, std::vector&); +extern template Status RingAllreduce(OpKernelContext*, Tensor&, Tensor*); +extern template Status RingAllreduce(OpKernelContext*, Tensor&, Tensor*); +extern template Status RingAllgather( + OpKernelContext*, Tensor&, Tensor*, std::vector&); +extern template Status RingAllgather( + OpKernelContext*, Tensor&, Tensor*, std::vector&); + +namespace { + +// Return true if the templated type is GPUDevice, otherwise false. +template bool IsGPUDevice(); +template<> bool IsGPUDevice() { return true; }; +template<> bool IsGPUDevice() { return false; }; + +// A callback to call after the MPI communication completes. Since the +// allreduce and allgather ops are asynchronous, this callback is what resumes +// computation after the reduction is completed. +typedef std::function)> CommunicationDoneCallback; + +// Table storing Tensors to be reduced, keyed by unique name. +// This table contains everything necessary to do the reduction: +// - Tensor: The tensor data. +// - OpKernelContext*: A context used to allocate the output or temporary values. +// - CommunicationDoneCallback: A callback to call with the result. +typedef std::unordered_map > TensorTable; + +// Table for storing Tensor metadata on rank zero. This is used for error +// checking and size calculations, as well as determining when a reduction is +// ready to be done (when all nodes are ready to do it). +typedef std::unordered_map > MessageTable; + +// The global state required for the MPI ops. +// +// MPI is a library that stores a lot of global per-program state and often +// requires running on a single thread. As a result, we have to have a single +// background thread responsible for all MPI operations, and communicate with +// that background thread through global state. +struct MPIGlobalState { + // An atomic boolean which is set to true when MPI is initialized. + // This ensures that MPI_Init is never called twice. + std::atomic_flag initialized_flag = ATOMIC_FLAG_INIT; + + // A mutex that needs to be used whenever MPI operations are done. + std::mutex mutex; + + // Tensors waiting to be allreduced or allgathered. + TensorTable tensor_table; + + // Queue of MPI requests waiting to be sent to the coordinator node. + std::queue message_queue; + + // Background thread running MPI communication. + std::thread background_thread; + + // Whether the background thread should shutdown. + bool shut_down = false; + + // Only exists on the coordinator node (rank zero). Maintains a count of + // how many nodes are ready to allreduce every tensor (keyed by tensor + // name). + std::unique_ptr message_table; + + // Whether MPI_Init has been completed on the background thread. + bool initialization_done = false; + + // The device that MPI was initialized on. (-1 for no GPU) + int device = -1; + + // Whether MPI_Init succeeded on the background thread. + Status init_status; + + // The MPI rank, local rank, and size. + int rank = 0; + int local_rank = 0; + int size = 1; + + // The CUDA stream used for data transfers and within-allreduce operations. + // A naive implementation would use the TensorFlow StreamExecutor CUDA + // stream. However, the allreduce and allgather require doing memory copies + // and kernel executions (for accumulation of values on the GPU). However, + // the subsequent operations must wait for those operations to complete, + // otherwise MPI (which uses its own stream internally) will begin the data + // transfers before the CUDA calls are complete. In order to wait for those + // CUDA operations, if we were using the TensorFlow stream, we would have to + // synchronize that stream; however, other TensorFlow threads may be + // submitting more work to that stream, so synchronizing on it can cause the + // allreduce to be delayed, waiting for compute totally unrelated to it in + // other parts of the graph. Overlaying memory transfers and compute during + // backpropagation is crucial for good performance, so we cannot use the + // TensorFlow stream, and must use our own stream. +#if GOOGLE_CUDA + cudaStream_t stream; + std::atomic_flag stream_created_flag = ATOMIC_FLAG_INIT; +#endif + + ~MPIGlobalState() { + // Make sure that the destructor of the background thread is safe to + // call. If a thread is still joinable (not detached or complete) its + // destructor cannot be called. + if(background_thread.joinable()) { + shut_down = true; + background_thread.join(); + } + } +}; + +// All the MPI state that must be stored globally per-process. +static MPIGlobalState mpi_global; + +// For clarify in argument lists. +#define RANK_ZERO 0 + +// A tag used for all coordinator messaging. +#define TAG_NOTIFY 1 + +// Store the MPIRequest for a name, and return whether the total count of +// MPIRequests for that tensor is now equal to the MPI size (and thus we are +// ready to reduce the tensor). +bool IncrementTensorCount( + std::unique_ptr& message_table, + MPIRequest msg, int mpi_size) { + auto name = msg.tensor_name(); + auto table_iter = message_table->find(name); + if(table_iter == message_table->end()) { + message_table->emplace(name, std::vector({msg})); + table_iter = message_table->find(name); + } else { + table_iter->second.push_back(msg); + } + + int count = table_iter->second.size(); + return count == mpi_size; +} + +// Once a tensor is ready to be reduced, the coordinator sends an MPIResponse +// instructing all ranks to start the reduction to all ranks. The MPIResponse +// also contains error messages in case the submitted MPIRequests were not +// valid (for example, contained mismatched shapes or types). +// +// Constructing the MPIResponse, thus, requires a whole lot of error checking. +MPIResponse ConstructMPIResponse(std::unique_ptr& message_table, std::string name) { + bool error = false; + auto it = message_table->find(name); + assert(it != message_table->end()); + + std::vector requests = it->second; + assert(requests.size() > 0); + + std::ostringstream error_message_stream; + + // Check that all data types being reduced or gathered are identical + auto data_type = requests[0].tensor_type(); + for(unsigned int i = 1; i < requests.size(); i++) { + auto request_type = requests[i].tensor_type(); + if(data_type != request_type) { + error = true; + error_message_stream + << "Mismatched data types: One rank had type " + << MPIDataType_Name(data_type) + << ", but another rank had type " + << MPIDataType_Name(request_type) + << "."; + break; + } + } + + // Check that all requested operations are the same + auto message_type = requests[0].request_type(); + for(unsigned int i = 1; i < requests.size(); i++) { + if(error) { + break; + } + + auto request_type = requests[i].request_type(); + if(message_type != request_type) { + error = true; + error_message_stream + << "Mismatched MPI operations: One rank did an " + << message_type + << ", but another rank did an " + << request_type + << "."; + break; + } + } + + // If we are doing an allreduce, check that all tensor shapes are identical + if(message_type == MPIRequest::ALLREDUCE) { + TensorShape tensor_shape; + for(auto it = requests[0].tensor_shape().begin(); + it != requests[0].tensor_shape().end(); it++) { + tensor_shape.AddDim(*it); + } + for(unsigned int i = 1; i < requests.size(); i++) { + if(error) { + break; + } + + TensorShape request_shape; + for(auto it = requests[i].tensor_shape().begin(); + it != requests[i].tensor_shape().end(); it++) { + request_shape.AddDim(*it); + } + if(tensor_shape != request_shape) { + error = true; + error_message_stream + << "Mismatched allreduce tensor shapes: " + << "One rank reduced a tensor of shape " + << tensor_shape.DebugString() + << ", but another rank sent a tensor of shape " + << request_shape.DebugString() + << "."; + break; + } + } + } + + // If we are doing an allgather, make sure all but the first dimension are + // the same. The first dimension may be different and the output tensor is + // the sum of the first dimension. Collect the sizes by rank. + std::vector tensor_sizes(requests.size()); + if(message_type == MPIRequest::ALLGATHER) { + TensorShape tensor_shape; + for(auto it = requests[0].tensor_shape().begin(); + it != requests[0].tensor_shape().end(); it++) { + tensor_shape.AddDim(*it); + } + + if(tensor_shape.dims() == 0) { + error = true; + error_message_stream + << "Rank zero tried to gather a rank-zero tensor."; + } else { + tensor_sizes[requests[0].request_rank()] = size_t(tensor_shape.dim_size(0)); + } + + for(unsigned int i = 1; i < requests.size(); i++) { + if(error) { + break; + } + + TensorShape request_shape; + for(auto it = requests[i].tensor_shape().begin(); + it != requests[i].tensor_shape().end(); it++) { + request_shape.AddDim(*it); + } + if(tensor_shape.dims() != request_shape.dims()) { + error = true; + error_message_stream + << "Mismatched allgather tensor shapes: " + << "One rank gathered a tensor of rank " + << tensor_shape.dims() + << ", but another rank sent a tensor of rank " + << request_shape.dims() + << "."; + break; + } + + bool dim_mismatch = false; + for(unsigned int dim = 1; dim < tensor_shape.dims(); dim++) { + if(tensor_shape.dim_size(dim) != request_shape.dim_size(dim)) { + error = true; + error_message_stream + << "Mismatched allgather tensor shapes: " + << "One rank gathered a tensor with dimension " + << dim << " equal to " << tensor_shape.dim_size(dim) + << ", but another rank sent a tensor with dimension " + << dim << " equal to " << request_shape.dim_size(dim) + << "."; + dim_mismatch = true; + break; + } + } + if(dim_mismatch) { + break; + } + + tensor_sizes[requests[i].request_rank()] = size_t(request_shape.dim_size(0)); + } + } + + MPIResponse response; + response.set_tensor_name(name); + if(error) { + std::string error_message = error_message_stream.str(); + response.set_response_type(MPIResponse::ERROR); + response.set_error_message(error_message); + } else if(message_type == MPIRequest::ALLGATHER) { + response.set_response_type(MPIResponse::ALLGATHER); + for(auto dim : tensor_sizes) { + response.add_tensor_sizes(dim); + } + } else if(message_type == MPIRequest::ALLREDUCE) { + response.set_response_type(MPIResponse::ALLREDUCE); + } + + // Clear all queued up requests for this name. They are now taken care of + // by the constructed MPI response. + message_table->erase(it); + + return response; +} + +// Process an MPIResponse by doing a reduction, a gather, or raising an error. +void PerformReductionOrGather(TensorTable& tensor_table, MPIResponse response) { + Tensor tensor; + OpKernelContext* context; + CommunicationDoneCallback callback; + bool on_gpu; + { + // Lock on the tensor table. + std::lock_guard guard(mpi_global.mutex); + + // We should never fail at finding this key in the tensor table. + auto name = response.tensor_name(); + auto iter = tensor_table.find(name); + assert(iter != tensor_table.end()); + + assert(response.response_type() == MPIResponse::ALLREDUCE || + response.response_type() == MPIResponse::ALLGATHER || + response.response_type() == MPIResponse::ERROR); + + std::tie(tensor, context, on_gpu, callback) = iter->second; + + // Clear the tensor table of this tensor and its callbacks; the rest of + // this function takes care of it. + tensor_table.erase(iter); + } + + // Use CPUDevice instead of GPUDevice if no CUDA, to ensure we don't + // link to non-existent symbols. +#if GOOGLE_CUDA +#define GPU_DEVICE_IF_CUDA GPUDevice +#else +#define GPU_DEVICE_IF_CUDA CPUDevice +#endif + + Tensor output; + Status status; + if(response.response_type() == MPIResponse::ALLGATHER) { + // Copy tensor sizes from the MPI response into a vector of size_t + std::vector tensor_sizes; + for(auto it = response.tensor_sizes().begin(); + it != response.tensor_sizes().end(); it++) { + tensor_sizes.push_back(size_t(*it)); + } + + if(tensor.dtype() == DT_FLOAT) { + status = on_gpu ? RingAllgather(context, tensor, &output, tensor_sizes) + : RingAllgather(context, tensor, &output, tensor_sizes); + } else if(tensor.dtype() == DT_INT32) { + status = on_gpu ? RingAllgather(context, tensor, &output, tensor_sizes) + : RingAllgather(context, tensor, &output, tensor_sizes); + } else { + status = errors::Unknown("Invalid tensor type for MPI allgather."); + } + } else if(response.response_type() == MPIResponse::ALLREDUCE) { + if(tensor.dtype() == DT_FLOAT) { + status = on_gpu ? RingAllreduce(context, tensor, &output) + : RingAllreduce(context, tensor, &output); + } else if(tensor.dtype() == DT_INT32) { + status = on_gpu ? RingAllreduce(context, tensor, &output) + : RingAllreduce(context, tensor, &output); + } else { + status = errors::Unknown("Invalid tensor type for MPI allreduce."); + } + } else if(response.response_type() == MPIResponse::ERROR) { + status = errors::FailedPrecondition(response.error_message()); + } + + if(status.ok()) { + callback(StatusOr(output)); + } else { + callback(StatusOr(status)); + } +} + +// The MPI background thread loop coordinates all the MPI processes and the +// tensor reductions. The design of the communicator mechanism is limited by a few considerations: +// +// 1. Some MPI implementations require all MPI calls to happen from a single thread. +// Since TensorFlow may use several threads for graph processing, this means we must have +// our own dedicated thread for dealing with MPI. +// 2. We want to gracefully handle errors, when MPI processes do not properly agree upon +// what should happen (such as mismatched types or shapes). To do so requires the MPI processes +// to know about the shapes and types of the relevant tensors on the other processes. +// 3. The MPI reductions and gathers should be able to happen in parallel +// with other ongoing operations. This means that they cannot be blocking +// ops, but rather must be async ops, the execution of which happens on a +// separate thread. +// 4. We cannot guarantee that all the MPI processes reduce their tensors +// in the same order, so we cannot dispatch one thread per tensor, +// otherwise we may end up dispatching many blocked threads and never make +// progress if we have a thread pool limit. +// +// The coordinator currently follows a master-worker paradigm. Rank zero acts +// as the master (the "coordinator"), whereas all other ranks are simply +// workers. Each rank runs its own background thread which progresses in ticks. +// In each tick, the following actions happen: +// +// a) The workers send an MPIRequest to the coordinator, indicating what +// they would like to do (which tensor they would like to gather and +// reduce, as well as their shape and type). They repeat this for every +// tensor that they would like to operate on. +// +// b) The workers send an empty "DONE" message to the coordinator to +// indicate that there are no more tensors they wish to operate on. +// +// c) The coordinator receives the MPIRequests from the workers, as well +// as from its own TensorFlow ops, and stores them in a request table. The +// coordinator continues to receive MPIRequest messages until it has +// received MPI_SIZE number of empty "DONE" messages. +// +// d) The coordinator finds all tensors that are ready to be reduced, +// gathered, or all operations that result in an error. For each of those, +// it sends an MPIResponse to all the workers. When no more MPIResponses +// are available, it sends a "DONE" response to the workers. If the process +// is being shutdown, it instead sends a "SHUTDOWN" response. +// +// e) The workers listen for MPIResponse messages, processing each one by +// doing the required reduce or gather, until they receive a "DONE" +// response from the coordinator. At that point, the tick ends. +// If instead of "DONE" they receive "SHUTDOWN", they exit their background loop. +void BackgroundThreadLoop(MPIGlobalState& state) { +#if GOOGLE_CUDA + // Set the device, so that this thread uses the same GPU context as the + // calling thread. + if(state.device > 0) { + cudaSetDevice(state.device); + cudaStreamCreate(&state.stream); + } +#endif + + // Initialize MPI. This must happen on the background thread, since not all + // MPI implementations support being called from multiple threads. + auto init_result = MPI_Init(NULL, NULL); + if(init_result != MPI_SUCCESS) { + state.init_status = errors::Unknown("Could not initialize MPI; MPI_Init() failed."); + } else { + state.init_status = Status::OK(); + } + + // Get MPI rank to determine if we are rank zero. + int rank; + MPI_Comm_rank(MPI_COMM_WORLD, &rank); + bool is_coordinator = rank == 0; + + // Get MPI size to determine how many tensors to wait for before reducing. + int size; + MPI_Comm_size(MPI_COMM_WORLD, &size); + + // Determine local rank by querying the local communicator. + MPI_Comm local_comm; + MPI_Comm_split_type(MPI_COMM_WORLD, MPI_COMM_TYPE_SHARED, 0, + MPI_INFO_NULL, &local_comm); + int local_rank; + MPI_Comm_rank(local_comm, &local_rank); + + state.rank = rank; + state.local_rank = local_rank; + state.size = size; + state.initialization_done = true; + + // Initialize the tensor count table. No tensors are available yet. + if(is_coordinator) { + state.message_table = + std::unique_ptr(new MessageTable()); + } + + // The coordinator sends a SHUTDOWN message to trigger shutdown. + bool should_shut_down = false; + do { + // This delay determines thread frequency and MPI message latency + std::this_thread::sleep_for(std::chrono::milliseconds(5)); + + // Copy the data structures from global state under this lock. + // However, don't keep the lock for the rest of the loop, so that + // enqueued stream callbacks can continue. + std::queue message_queue; + { + std::lock_guard guard(state.mutex); + while(!state.message_queue.empty()) { + MPIRequest message = state.message_queue.front(); + state.message_queue.pop(); + message_queue.push(message); + } + } + + // Collect all tensors that are ready to be reduced. Record them in the + // tensor count table (rank zero) or send them to rank zero to be + // recorded (everyone else). + std::vector ready_to_reduce; + while(!message_queue.empty()) { + // Pop the first available message message + MPIRequest message = message_queue.front(); + message_queue.pop(); + + if(is_coordinator) { + bool reduce = IncrementTensorCount(state.message_table, message, size); + if(reduce) { + ready_to_reduce.push_back(message.tensor_name()); + } + } else { + std::string encoded_message; + message.SerializeToString(&encoded_message); + MPI_Send(encoded_message.c_str(), encoded_message.length() + 1, + MPI_BYTE, RANK_ZERO, TAG_NOTIFY, MPI_COMM_WORLD); + } + } + + // Rank zero has put all its own tensors in the tensor count table. + // Now, it should count all the tensors that are coming from other + // ranks at this tick. It should keep getting tensors until it gets a + // DONE message from all the other ranks. + if(is_coordinator) { + // Count of DONE messages. Keep receiving messages until the number + // of messages is equal to the number of processes. Initialize to + // one since the coordinator is effectively done. + int completed_ranks = 1; + while(completed_ranks != size) { + MPI_Status status; + MPI_Probe(MPI_ANY_SOURCE, TAG_NOTIFY, MPI_COMM_WORLD, &status); + + // Find number of characters in message (including zero byte). + int source_rank = status.MPI_SOURCE; + int msg_length; + MPI_Get_count(&status, MPI_BYTE, &msg_length); + + // If the length is zero, this is a DONE message. + if(msg_length == 0) { + completed_ranks++; + MPI_Recv(NULL, 0, MPI_BYTE, source_rank, TAG_NOTIFY, MPI_COMM_WORLD, &status); + continue; + } + + // Get tensor name from MPI into an std::string. + char* buffer = new char[msg_length]; + MPI_Recv(buffer, msg_length, MPI_BYTE, source_rank, + TAG_NOTIFY, MPI_COMM_WORLD, &status); + std::string received_data(buffer); + delete[] buffer; + + MPIRequest received_message; + received_message.ParseFromString(received_data); + auto received_name = received_message.tensor_name(); + + bool reduce = IncrementTensorCount( + state.message_table, received_message, size); + if(reduce) { + ready_to_reduce.push_back(received_name); + } + } + + // At this point, rank zero should have a fully updated tensor count + // table and should know all the tensors that need to be reduced or + // gathered, and everyone else should have sent all their information + // to rank zero. We can now do reductions and gathers; rank zero will + // choose which ones and in what order, and will notify the other ranks + // before doing each reduction. + for(int i = 0; i < ready_to_reduce.size(); i++) { + // Notify all nodes which tensor we'd like to reduce at this step. + auto name = ready_to_reduce[i]; + MPIResponse response = ConstructMPIResponse(state.message_table, name); + + std::string encoded_response; + response.SerializeToString(&encoded_response); + for(int r = 1; r < size; r++) { + MPI_Send(encoded_response.c_str(), encoded_response.length() + 1, + MPI_BYTE, r, TAG_NOTIFY, MPI_COMM_WORLD); + } + + // Perform the reduction. All nodes should end up performing the same reduction. + PerformReductionOrGather(state.tensor_table, response); + } + + // Notify all nodes that we are done with the reductions for this tick. + MPIResponse done_response; + should_shut_down = state.shut_down; + done_response.set_response_type( + should_shut_down ? MPIResponse::SHUTDOWN : MPIResponse::DONE); + std::string encoded_response; + done_response.SerializeToString(&encoded_response); + for(int r = 1; r < size; r++) { + MPI_Send(encoded_response.c_str(), encoded_response.length() + 1, + MPI_BYTE, r, TAG_NOTIFY, MPI_COMM_WORLD); + } + } else { + // Notify the coordinator that this node is done sending messages. + // A DONE message is encoded as a zero-length message. + MPI_Send(NULL, 0, MPI_BYTE, RANK_ZERO, TAG_NOTIFY, MPI_COMM_WORLD); + + // Receive names for tensors to reduce from rank zero. + // Once we receive a empty DONE message, stop waiting for more names. + while(true) { + MPI_Status status; + MPI_Probe(0, TAG_NOTIFY, MPI_COMM_WORLD, &status); + + // Find number of characters in message (including zero byte). + int msg_length; + MPI_Get_count(&status, MPI_BYTE, &msg_length); + + // Get tensor name from MPI into an std::string. + char* buffer = new char[msg_length]; + MPI_Recv(buffer, msg_length, MPI_BYTE, 0, + TAG_NOTIFY, MPI_COMM_WORLD, &status); + std::string received_message(buffer); + delete[] buffer; + + MPIResponse response; + response.ParseFromString(received_message); + if(response.response_type() == MPIResponse::DONE) { + // No more messages this tick + break; + } else if(response.response_type() == MPIResponse::SHUTDOWN) { + // No more messages this tick, and the background thread should shut down + should_shut_down = true; + break; + } else { + // Process the current message + PerformReductionOrGather(state.tensor_table, response); + } + } + } + } while(!should_shut_down); + + MPI_Finalize(); +} + +// Initialize MPI and start the MPI background thread. Ensure that this is +// only done once no matter how many times this function is called. +Status InitializeMPIOnce(bool gpu) { + // Ensure MPI is only initialized once. + if(mpi_global.initialized_flag.test_and_set()) + return mpi_global.init_status; + + int current_device = -1; +#if GOOGLE_CUDA + if(gpu) { + cudaGetDevice(¤t_device); + } +#endif + mpi_global.device = current_device; + + // Start the MPI background thread, which assumes MPI is initialized + mpi_global.background_thread = std::thread(BackgroundThreadLoop, + std::ref(mpi_global)); + + // Wait to ensure that the background thread has finished initializing MPI. + while(!mpi_global.initialization_done) { + std::this_thread::sleep_for(std::chrono::milliseconds(1)); + } + + return mpi_global.init_status; +} + +// Check that MPI is initialized and is initialized on the same device. +Status InitializedMPIOnSameDevice(bool gpu) { + if(!mpi_global.initialization_done) { + return errors::FailedPrecondition("MPI has not been initialized; use tf.contrib.mpi.Session."); + } + if(gpu && mpi_global.device < 0) { + return errors::FailedPrecondition("MPI op on GPU, but initialized on CPU."); + } + +#if GOOGLE_CUDA + if(gpu) { + int current_device = -1; + cudaGetDevice(¤t_device); + if(current_device != mpi_global.device) { + return errors::FailedPrecondition("MPI op on different GPU than MPI was initialized on."); + } + } +#endif + + return Status::OK(); +} + +// Convert a TensorFlow DataType to our MPIDataType. +Status DataTypeToMPIType(DataType tf_dtype, MPIDataType* mpi_dtype) { + if(tf_dtype == DT_FLOAT) { + *mpi_dtype = TF_MPI_FLOAT32; + } else if(tf_dtype == DT_INT32) { + *mpi_dtype = TF_MPI_INT32; + } else { + return errors::FailedPrecondition("Invalid tensor type passed."); + } + return Status::OK(); +} + +// MPI must be initialized and the background thread must be running before +// this function is called. +void EnqueueTensorAllreduce( + OpKernelContext* context, + const Tensor& tensor, + const std::string name, + const bool on_gpu, + CommunicationDoneCallback callback) { + MPIDataType dtype; + Status status = DataTypeToMPIType(tensor.dtype(), &dtype); + if(!status.ok()) { + callback(StatusOr(status)); + return; + } + + int rank; + MPI_Comm_rank(MPI_COMM_WORLD, &rank); + + MPIRequest message; + message.set_request_rank(rank); + message.set_tensor_name(name); + message.set_tensor_type(dtype); + message.set_request_type(MPIRequest::ALLREDUCE); + for(int i = 0; i < tensor.shape().dims(); i++) { + message.add_tensor_shape(tensor.shape().dim_size(i)); + } + + std::lock_guard guard(mpi_global.mutex); + std::tuple record(tensor, context, on_gpu, callback); + mpi_global.tensor_table.emplace(name, record); + mpi_global.message_queue.push(message); +} + +// MPI must be initialized and the background thread must be running before +// this function is called. +void EnqueueTensorAllgather( + OpKernelContext* context, + const Tensor& tensor, + const std::string name, + const bool on_gpu, + CommunicationDoneCallback callback) { + MPIDataType dtype; + Status status = DataTypeToMPIType(tensor.dtype(), &dtype); + if(!status.ok()) { + callback(StatusOr(status)); + return; + } + + int rank; + MPI_Comm_rank(MPI_COMM_WORLD, &rank); + + MPIRequest message; + message.set_request_rank(rank); + message.set_tensor_name(name); + message.set_tensor_type(dtype); + message.set_request_type(MPIRequest::ALLGATHER); + for(int i = 0; i < tensor.shape().dims(); i++) { + message.add_tensor_shape(tensor.shape().dim_size(i)); + } + + std::lock_guard guard(mpi_global.mutex); + std::tuple record(tensor, context, on_gpu, callback); + mpi_global.tensor_table.emplace(name, record); + mpi_global.message_queue.push(message); +} +} + +#if GOOGLE_CUDA +cudaStream_t CudaStreamForMPI() { + return mpi_global.stream; +} +#endif + +// Op to initialize MPI in the current process. The settings used in the +// configuration are the same that must be used for all future MPI ops. +template +class MPIInitOp : public OpKernel { + public: + explicit MPIInitOp(OpKernelConstruction* context) : OpKernel(context) { + } + + + void Compute(OpKernelContext* context) override { + bool on_gpu = IsGPUDevice(); + OP_REQUIRES_OK(context, InitializeMPIOnce(on_gpu)); + } +}; + +REGISTER_KERNEL_BUILDER(Name("MPIInit").Device(DEVICE_CPU), MPIInitOp); +#if GOOGLE_CUDA +REGISTER_KERNEL_BUILDER(Name("MPIInit").Device(DEVICE_GPU), MPIInitOp); +#endif + +REGISTER_OP("MPIInit") + .Doc(R"doc( +Initialize MPI for the current process. + +If this is run on a GPU, then that GPU must be used for all future MPI +operations. If it is run on CPU, then all future MPI operations must also run on +CPU. +)doc"); + +// Op to get the current MPI Size. +template +class MPISizeOp : public OpKernel { + public: + explicit MPISizeOp(OpKernelConstruction* context) : OpKernel(context) { } + + + void Compute(OpKernelContext* context) override { + bool on_gpu = IsGPUDevice(); + OP_REQUIRES_OK(context, InitializedMPIOnSameDevice(on_gpu)); + + // Get the number of processes + int world_size = mpi_global.size; + + // Write integer to output tensor + Tensor* output; + OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape({}), &output)); + + auto flat = output->flat(); + flat(0) = world_size; + } +}; + +REGISTER_KERNEL_BUILDER(Name("MPISize").Device(DEVICE_CPU), MPISizeOp); +#if GOOGLE_CUDA +REGISTER_KERNEL_BUILDER(Name("MPISize").Device(DEVICE_GPU).HostMemory("size"), MPISizeOp); +#endif + +REGISTER_OP("MPISize") + .Output("size: int32") + .SetShapeFn([](shape_inference::InferenceContext* c) { + c->set_output(0, c->Scalar()); + return Status::OK(); + }) + .Doc(R"doc( +Returns the number of running MPI processes. + +More precisely, returns the number of MPI processes in the group associated +with the MPI_COMM_WORLD communicator. + +size: Size of the MPI group. +)doc"); + +// Op to get the current MPI Rank. +template +class MPIRankOp : public OpKernel { + public: + explicit MPIRankOp(OpKernelConstruction* context) : OpKernel(context) { } + + void Compute(OpKernelContext* context) override { + bool on_gpu = IsGPUDevice(); + OP_REQUIRES_OK(context, InitializedMPIOnSameDevice(on_gpu)); + + // Get the processor index + int rank = mpi_global.rank; + + // Write integer to output tensor + Tensor* output; + OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape({}), &output)); + + auto flat = output->flat(); + flat(0) = rank; + } +}; + +REGISTER_KERNEL_BUILDER(Name("MPIRank").Device(DEVICE_CPU), MPIRankOp); +#if GOOGLE_CUDA +REGISTER_KERNEL_BUILDER(Name("MPIRank").Device(DEVICE_GPU).HostMemory("rank"), MPIRankOp); +#endif + +REGISTER_OP("MPIRank") + .Output("rank: int32") + .SetShapeFn([](shape_inference::InferenceContext* c) { + c->set_output(0, c->Scalar()); + return Status::OK(); + }) + .Doc(R"doc( +Returns the index of the current process in the MPI group. + +More precisely, returns the rank of the calling process in the MPI_COMM_WORLD +communicator. + +rank: Rank of the calling process. +)doc"); + + +// Op to get the current local MPI Rank. +template +class MPILocalRankOp : public OpKernel { + public: + explicit MPILocalRankOp(OpKernelConstruction* context) : OpKernel(context) { } + + void Compute(OpKernelContext* context) override { + bool on_gpu = IsGPUDevice(); + OP_REQUIRES_OK(context, InitializedMPIOnSameDevice(on_gpu)); + + int local_rank = mpi_global.local_rank; + + // Write integer to output tensor + Tensor* output; + OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape({}), &output)); + + auto flat = output->flat(); + flat(0) = local_rank; + } +}; + +REGISTER_KERNEL_BUILDER(Name("MPILocalRank").Device(DEVICE_CPU), MPILocalRankOp); +#if GOOGLE_CUDA +REGISTER_KERNEL_BUILDER(Name("MPILocalRank").Device(DEVICE_GPU).HostMemory("rank"), MPILocalRankOp); +#endif + +REGISTER_OP("MPILocalRank") + .Output("rank: int32") + .SetShapeFn([](shape_inference::InferenceContext* c) { + c->set_output(0, c->Scalar()); + return Status::OK(); + }) + .Doc(R"doc( +Returns the index of the current process in the node it is on. + +More precisely, returns the rank of the calling process in communicator that +only spans the MPI processes running on that node. + +rank: Rank of the calling process on the node it is on. +)doc"); + +template +class MPIAllreduceOp : public AsyncOpKernel { + public: + explicit MPIAllreduceOp(OpKernelConstruction* context) : AsyncOpKernel(context) { } + + void ComputeAsync(OpKernelContext* context, DoneCallback done) override { + bool on_gpu = IsGPUDevice(); + OP_REQUIRES_OK(context, InitializedMPIOnSameDevice(on_gpu)); + + auto device_context = context->op_device_context(); + auto node_name = name(); + auto callback = [node_name, done, context, on_gpu] { + auto tensor = context->input(0); + EnqueueTensorAllreduce(context, tensor, node_name, on_gpu, + [node_name, done, context](StatusOr status) { + if(status.ok()) { + Tensor output = status.ValueOrDie(); + context->set_output(0, output); + } + context->SetStatus(status.status()); + done(); + }); + }; + + // If we are on a CPU, our device context will be null and we can't + // get a stream to enqueue this on. On a CPU this op is called when the + // data is already available, so we can just immediately do the allreduce; + // we don't have to wait for the data to get populated. +#if GOOGLE_CUDA + if(device_context == nullptr) { + callback(); + } else { + auto stream = device_context->stream(); + stream->ThenDoHostCallback(callback); + } +#else + callback(); +#endif + } +}; + +REGISTER_KERNEL_BUILDER(Name("MPIAllreduce").Device(DEVICE_CPU), MPIAllreduceOp); +#if GOOGLE_CUDA +REGISTER_KERNEL_BUILDER(Name("MPIAllreduce").Device(DEVICE_GPU), MPIAllreduceOp); +#endif + +REGISTER_OP("MPIAllreduce") + .Attr("T: {int32, float32}") + .Input("tensor: T") + .Output("sum: T") + .SetShapeFn([](shape_inference::InferenceContext* c) { + c->set_output(0, c->input(0)); + return Status::OK(); + }) + .Doc(R"doc( +Perform an MPI Allreduce on a tensor. All other processes that do a reduction +on a tensor with the same name must have the same dimension for that tensor. +Tensors are reduced with other tensors that have the same node name for the +allreduce. + +Arguments + tensor: A tensor to reduce. + +Output + sum: A tensor with the same shape as `tensor`, summed across all MPI processes. +)doc"); + +template +class MPIAllgatherOp : public AsyncOpKernel { + public: + explicit MPIAllgatherOp(OpKernelConstruction* context) : AsyncOpKernel(context) { + } + + void ComputeAsync(OpKernelContext* context, DoneCallback done) override { + bool on_gpu = IsGPUDevice(); + OP_REQUIRES_OK(context, InitializedMPIOnSameDevice(on_gpu)); + + auto device_context = context->op_device_context(); + auto node_name = name(); + auto callback = [node_name, done, context, on_gpu] { + auto tensor = context->input(0); + EnqueueTensorAllgather(context, tensor, node_name, on_gpu, + [node_name, done, context](StatusOr status) { + if(status.ok()) { + Tensor output = status.ValueOrDie(); + context->set_output(0, output); + } + context->SetStatus(status.status()); + done(); + }); + }; + + // If we are on a CPU, our device context will be null and we can't + // get a stream to enqueue this on. On a CPU this op is called when the + // data is already available, so we can just immediately do the allgather; + // we don't have to wait for the data to get populated. +#if GOOGLE_CUDA + if(device_context == nullptr) { + callback(); + } else { + auto stream = device_context->stream(); + stream->ThenDoHostCallback(callback); + } +#else + callback(); +#endif + } +}; + +REGISTER_KERNEL_BUILDER(Name("MPIAllgather").Device(DEVICE_CPU), MPIAllgatherOp); +#if GOOGLE_CUDA +REGISTER_KERNEL_BUILDER(Name("MPIAllgather").Device(DEVICE_GPU), MPIAllgatherOp); +#endif + +REGISTER_OP("MPIAllgather") + .Attr("T: {int32, float32}") + .Input("tensor: T") + .Output("output: T") + .SetShapeFn([](shape_inference::InferenceContext* c) { + shape_inference::ShapeHandle output; + TF_RETURN_IF_ERROR(c->ReplaceDim(c->input(0), 0, c->UnknownDim(), &output)); + c->set_output(0, output); + return Status::OK(); + }) + .Doc(R"doc( +Perform an MPI Allgather on a tensor. All other processes that do a gather on a +tensor with the same name must have the same rank for that tensor, and have the +same dimension on all but the first dimension. + +Arguments + tensor: A tensor to gather. + +Output + gathered: A tensor with the same shape as `tensor` except for the first dimension. +)doc"); + +} // namespace mpi +} // namespace contrib +} // namespace tensorflow diff --git a/tensorflow/contrib/mpi/mpi_ops.py b/tensorflow/contrib/mpi/mpi_ops.py new file mode 100644 index 000000000..8e7734a14 --- /dev/null +++ b/tensorflow/contrib/mpi/mpi_ops.py @@ -0,0 +1,154 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= +"""Inter-process communication using MPI.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.framework import errors +from tensorflow.python.framework import load_library +from tensorflow.python.framework import ops +from tensorflow.python.platform import resource_loader +from tensorflow.python.platform import tf_logging as logging + + +def _load_library(name, op_list=None): + """Loads a .so file containing the specified operators. + + Args: + name: The name of the .so file to load. + op_list: A list of names of operators that the library should have. If None + then the .so file's contents will not be verified. + + Raises: + NameError if one of the required ops is missing. + """ + try: + filename = resource_loader.get_path_to_datafile(name) + library = load_library.load_op_library(filename) + for expected_op in (op_list or []): + for lib_op in library.OP_LIST.op: + if lib_op.name == expected_op: + break + else: + raise NameError( + 'Could not find operator %s in dynamic library %s' % + (expected_op, name)) + return library + except errors.NotFoundError: + logging.warning('%s file could not be loaded.', name) + + +MPI_LIB = _load_library('mpi.so', ['MPISize', 'MPIRank', 'MPILocalRank', + 'MPIAllgather', 'MPIAllreduce']) + + +def size(name=None): + """An op which returns the number of MPI processes. + + This is equivalent to running `MPI_Comm_size(MPI_COMM_WORLD, ...)` to get the + size of the global communicator. + + Returns: + An integer scalar containing the number of MPI processes. + """ + return MPI_LIB.mpi_size(name=name) + + +ops.NotDifferentiable('MPISize') + + +def rank(name=None): + """An op which returns the MPI rank of the calling process. + + This is equivalent to running `MPI_Comm_rank(MPI_COMM_WORLD, ...)` to get the + rank of the current process in the global communicator. + + Returns: + An integer scalar with the MPI rank of the calling process. + """ + return MPI_LIB.mpi_rank(name=name) + + +ops.NotDifferentiable('MPIRank') + + +def init(name=None): + """An op which initializes MPI on the device on which it is run. + + All future MPI ops must be run on the same device that the `init` op was run + on. + """ + return MPI_LIB.mpi_init(name=name) + + +ops.NotDifferentiable('MPIInit') + + +def local_rank(name=None): + """An op which returns the local MPI rank of the calling process, within the + node that it is running on. For example, if there are seven processes running + on a node, their local ranks will be zero through six, inclusive. + + This is equivalent to running `MPI_Comm_rank(...)` on a new communicator + which only includes processes on the same node. + + Returns: + An integer scalar with the local MPI rank of the calling process. + """ + return MPI_LIB.mpi_local_rank(name=name) + + +ops.NotDifferentiable('MPILocalRank') + + +def _allreduce(tensor, name=None): + """An op which sums an input tensor over all the MPI processes. + + The reduction operation is keyed by the name of the op. The tensor type and + shape must be the same on all MPI processes for a given name. The reduction + will not start until all processes are ready to send and receive the tensor. + + Returns: + A tensor of the same shape and type as `tensor`, summed across all + processes. + """ + return MPI_LIB.mpi_allreduce(tensor, name=name) + + +ops.NotDifferentiable('MPIAllreduce') + + +def allgather(tensor, name=None): + """An op which concatenates the input tensor with the same input tensor on + all other MPI processes. + + The concatenation is done on the first dimension, so the input tensors on the + different processes must have the same rank and shape, except for the first + dimension, which is allowed to be different. + + Returns: + A tensor of the same type as `tensor`, concatenated on dimension zero + across all processes. The shape is identical to the input shape, except for + the first dimension, which may be greater and is the sum of all first + dimensions of the tensors in different MPI processes. + """ + return MPI_LIB.mpi_allgather(tensor, name=name) + + +ops.NotDifferentiable('MPIAllgather') + + diff --git a/tensorflow/contrib/mpi/mpi_ops_test.py b/tensorflow/contrib/mpi/mpi_ops_test.py new file mode 100644 index 000000000..84953efc2 --- /dev/null +++ b/tensorflow/contrib/mpi/mpi_ops_test.py @@ -0,0 +1,301 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= + +"""Tests for tensorflow.contrib.mpi.mpi_ops.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os.path +import itertools + +import tensorflow as tf + +import tensorflow.contrib.mpi as mpi + + +def mpi_env_rank_and_size(): + """Get MPI rank and size from environment variables and return them as a + tuple of integers. + + Most MPI implementations have an `mpirun` or `mpiexec` command that will + run an MPI executable and set up all communication necessary between the + different processors. As part of that set up, they will set environment + variables that contain the rank and size of the MPI_COMM_WORLD + communicator. We can read those environment variables from Python in order + to ensure that `mpi.rank()` and `mpi.size()` return the expected values. + + Since MPI is just a standard, not an implementation, implementations + typically choose their own environment variable names. This function tries + to support several different implementation, but really it only needs to + support whatever implementation we want to use for the TensorFlow test + suite. + + If this is not running under MPI, then defaults of rank zero and size one + are returned. (This is appropriate because when you call MPI_Init in an + application not started with mpirun, it will create a new independent + communicator with only one process in it.) + """ + rank_env = "PMI_RANK OMPI_COMM_WORLD_RANK".split() + size_env = "PMI_SIZE OMPI_COMM_WORLD_SIZE".split() + + for rank_var, size_var in zip(rank_env, size_env): + rank = os.environ.get(rank_var) + size = os.environ.get(size_var) + if rank is not None and size is not None: + return int(rank), int(size) + + # Default to rank zero and size one if there are no environment variables + return 0, 1 + + +class MPITests(tf.test.TestCase): + """ + Tests for MPI ops in tensorflow.contrib.mpi. + """ + + def test_mpi_rank(self): + """Test that the rank returned by mpi.rank() is correct.""" + true_rank, _ = mpi_env_rank_and_size() + with self.test_session() as session: + rank = session.run(mpi.rank()) + self.assertEqual(true_rank, rank) + + def test_mpi_size(self): + """Test that the size returned by mpi.size() is correct.""" + _, true_size = mpi_env_rank_and_size() + with self.test_session() as session: + size = session.run(mpi.size()) + self.assertEqual(true_size, size) + + def test_mpi_allreduce_cpu(self): + """Test on CPU that the allreduce correctly sums 1D, 2D, 3D tensors.""" + with self.test_session() as session: + size = session.run(mpi.size()) + + dtypes = [tf.int32, tf.float32] + dims = [1, 2, 3] + for dtype, dim in itertools.product(dtypes, dims): + tf.set_random_seed(1234) + tensor = tf.random_uniform([17] * dim, -100, 100, + dtype=dtype) + summed = mpi.allreduce(tensor, average=False) + multiplied = tensor * size + max_difference = tf.reduce_max(tf.abs(summed - multiplied)) + + # Threshold for floating point equality depends on number of + # ranks, since we're comparing against precise multiplication. + if size <= 3: + threshold = 0 + elif size < 10: + threshold = 1e-4 + elif size < 15: + threshold = 5e-4 + else: + break + + diff = session.run(max_difference) + self.assertTrue(diff <= threshold, + "mpi.allreduce produces incorrect results") + + def test_mpi_allreduce_gpu(self): + """Test that the allreduce works on GPUs. + + This test will crash badly if used with an MPI implementation that does + not support GPU memory transfers directly, as it will call MPI_Send on + a GPU data pointer.""" + # Only do this test if there are GPUs available. + if not tf.test.is_gpu_available(cuda_only=True): + return + + no_gpus = tf.GPUOptions(visible_device_list="") + cpu_config = tf.ConfigProto(gpu_options=no_gpus) + with self.test_session(config=cpu_config) as session: + local_rank = session.run(mpi.local_rank()) + + one_gpu = tf.GPUOptions(visible_device_list=str(local_rank)) + gpu_config = tf.ConfigProto(gpu_options=one_gpu) + with self.test_session(config=gpu_config) as session: + size = session.run(mpi.size()) + + dtype = tf.float32 + dim = 3 + with tf.device("/gpu:0"): + tf.set_random_seed(1234) + tensor = tf.random_uniform([17] * dim, -100, 100, dtype=dtype) + summed = mpi.allreduce(tensor, average=False) + multiplied = tensor * size + max_difference = tf.reduce_max(tf.abs(summed - multiplied)) + + # Threshold for floating point equality depends on number of + # ranks, since we're comparing against precise multiplication. + if size <= 3: + threshold = 0 + elif size < 10: + threshold = 1e-4 + elif size < 15: + threshold = 5e-4 + else: + return + + diff = session.run(max_difference) + self.assertTrue(diff <= threshold, + "mpi.allreduce on GPU produces incorrect results") + + def test_mpi_allreduce_error(self): + """Test that the allreduce raises an error if different ranks try to + send tensors of different rank or dimension.""" + with self.test_session() as session: + rank = session.run(mpi.rank()) + size = session.run(mpi.size()) + + # This test does not apply if there is only one worker. + if size == 1: + return + + # Same rank, different dimension + tf.set_random_seed(1234) + dims = [17 + rank] * 3 + tensor = tf.random_uniform(dims, -1.0, 1.0) + with self.assertRaises(tf.errors.FailedPreconditionError): + session.run(mpi.allreduce(tensor)) + + # Same number of elements, different rank + tf.set_random_seed(1234) + if rank == 0: + dims = [17, 23 * 57] + else: + dims = [17, 23, 57] + tensor = tf.random_uniform(dims, -1.0, 1.0) + with self.assertRaises(tf.errors.FailedPreconditionError): + session.run(mpi.allreduce(tensor)) + + def test_mpi_allreduce_type_error(self): + """Test that the allreduce raises an error if different ranks try to + send tensors of different type.""" + with self.test_session() as session: + rank = session.run(mpi.rank()) + size = session.run(mpi.size()) + + # This test does not apply if there is only one worker. + if size == 1: + return + + # Same rank, different dimension + dims = [17] * 3 + tensor = tf.ones(dims, + dtype=tf.int32 if rank % 2 == 0 else tf.float32) + with self.assertRaises(tf.errors.FailedPreconditionError): + session.run(mpi.allreduce(tensor)) + + def test_mpi_allgather(self): + """Test that the allgather correctly gathers 1D, 2D, 3D tensors.""" + with self.test_session() as session: + size = session.run(mpi.size()) + rank = session.run(mpi.rank()) + + dtypes = tf.int32, tf.float32 + dims = 1, 2, 3 + for dtype, dim in itertools.product(dtypes, dims): + tensor = tf.ones([17] * dim, dtype=dtype) * rank + gathered = mpi.allgather(tensor) + + gathered_tensor = session.run(gathered) + self.assertEqual(list(gathered_tensor.shape), + [17 * size] + [17] * (dim - 1)) + + for i in range(size): + rank_tensor = tf.slice(gathered_tensor, + [i * 17] + [0] * (dim - 1), + [17] + [-1] * (dim - 1)) + self.assertEqual(list(rank_tensor.shape), [17] * dim) + self.assertTrue( + session.run(tf.reduce_all(tf.equal(rank_tensor, i))), + "mpi.allgather produces incorrect gathered tensor") + + def test_mpi_allgather_variable_size(self): + """Test that the allgather correctly gathers 1D, 2D, 3D tensors, + even if those tensors have different sizes along the first dim.""" + with self.test_session() as session: + size = session.run(mpi.size()) + rank = session.run(mpi.rank()) + + dtypes = tf.int32, tf.float32 + dims = 1, 2, 3 + for dtype, dim in itertools.product(dtypes, dims): + # Support tests up to MPI Size of 35 + if size > 35: + break + + tensor_sizes = [17, 32, 81, 12, 15, 23, 22] * 5 + tensor_sizes = tensor_sizes[:size] + + tensor = tf.ones([tensor_sizes[rank]] + [17] * (dim - 1), + dtype=dtype) * rank + gathered = mpi.allgather(tensor) + + gathered_tensor = session.run(gathered) + expected_size = sum(tensor_sizes) + self.assertEqual(list(gathered_tensor.shape), + [expected_size] + [17] * (dim - 1)) + + for i in range(size): + rank_size = [tensor_sizes[i]] + [17] * (dim - 1) + rank_tensor = tf.slice( + gathered, [sum(tensor_sizes[:i])] + [0] * (dim - 1), + rank_size) + self.assertEqual(list(rank_tensor.shape), rank_size) + self.assertTrue( + session.run(tf.reduce_all(tf.equal(rank_tensor, i))), + "mpi.allgather produces incorrect gathered tensor") + + def test_mpi_allgather_error(self): + """Test that the allgather returns an error if any dimension besides + the first is different among the tensors being gathered.""" + with self.test_session() as session: + rank = session.run(mpi.rank()) + size = session.run(mpi.size()) + + # This test does not apply if there is only one worker. + if size == 1: + return + + tensor_size = [17] * 3 + tensor_size[1] = 10 * (rank + 1) + tensor = tf.ones(tensor_size, dtype=tf.float32) * rank + with self.assertRaises(tf.errors.FailedPreconditionError): + session.run(mpi.allgather(tensor)) + + def test_mpi_allgather_type_error(self): + """Test that the allgather returns an error if the types being gathered + differ among the processes""" + with self.test_session() as session: + rank = session.run(mpi.rank()) + size = session.run(mpi.size()) + + # This test does not apply if there is only one worker. + if size == 1: + return + + tensor_size = [17] * 3 + dtype = tf.int32 if rank % 2 == 0 else tf.float32 + tensor = tf.ones(tensor_size, dtype=dtype) * rank + with self.assertRaises(tf.errors.FailedPreconditionError): + session.run(mpi.allgather(tensor)) + + +if __name__ == '__main__': + tf.test.main() diff --git a/tensorflow/contrib/mpi/ring.cc b/tensorflow/contrib/mpi/ring.cc new file mode 100644 index 000000000..2c5da128f --- /dev/null +++ b/tensorflow/contrib/mpi/ring.cc @@ -0,0 +1,48 @@ +#define EIGEN_USE_THREADS + +#include "tensorflow/contrib/mpi/ring.h" + +namespace tensorflow { +namespace contrib { +namespace mpi { + +using CPUDevice = Eigen::ThreadPoolDevice; + +extern template MPI_Datatype MPIType(); +extern template MPI_Datatype MPIType(); +extern template DataType TensorFlowDataType(); +extern template DataType TensorFlowDataType(); + + +// Generate all necessary specializations for RingAllreduce. +template Status RingAllreduce(OpKernelContext*, Tensor&, Tensor*); +template Status RingAllreduce(OpKernelContext*, Tensor&, Tensor*); + +// Generate all necessary specializations for RingAllgather. +template Status RingAllgather( + OpKernelContext*, Tensor&, Tensor*, std::vector&); +template Status RingAllgather( + OpKernelContext*, Tensor&, Tensor*, std::vector&); + +// Copy data on a CPU using a straight-forward memcpy. +template<> void CopyTensorData(void* dst, void* src, size_t size) { + std::memcpy(dst, src, size); +}; + +// Accumulate values on a CPU. +template<> void AccumulateTensorData( + float* dst, float* src, size_t size) { + for(unsigned int i = 0; i < size; i++) { + dst[i] += src[i]; + } +}; +template<> void AccumulateTensorData( + int* dst, int* src, size_t size) { + for(unsigned int i = 0; i < size; i++) { + dst[i] += src[i]; + } +} + +} +} +} diff --git a/tensorflow/contrib/mpi/ring.cu.cc b/tensorflow/contrib/mpi/ring.cu.cc new file mode 100644 index 000000000..99795e7c7 --- /dev/null +++ b/tensorflow/contrib/mpi/ring.cu.cc @@ -0,0 +1,68 @@ + +#if GOOGLE_CUDA + +#define EIGEN_USE_GPU + +#include "tensorflow/contrib/mpi/ring.h" + +namespace tensorflow { +namespace contrib { +namespace mpi { + +using CPUDevice = Eigen::ThreadPoolDevice; + +template<> MPI_Datatype MPIType() { return MPI_FLOAT; }; +template<> MPI_Datatype MPIType() { return MPI_INT; }; + +template<> DataType TensorFlowDataType() { return DT_FLOAT; }; +template<> DataType TensorFlowDataType() { return DT_INT32; }; + +// Generate all necessary specializations for RingAllreduce. +template Status RingAllreduce(OpKernelContext*, Tensor&, Tensor*); +template Status RingAllreduce(OpKernelContext*, Tensor&, Tensor*); + +// Generate all necessary specializations for RingAllgather. +template Status RingAllgather( + OpKernelContext*, Tensor&, Tensor*, std::vector&); +template Status RingAllgather( + OpKernelContext*, Tensor&, Tensor*, std::vector&); + +// Synchronously copy data on the GPU, using a different stream than the default +// and than TensorFlow to avoid synchronizing on operations unrelated to the +// allreduce. +template<> void CopyTensorData(void* dst, void* src, size_t size) { + auto stream = CudaStreamForMPI(); + cudaMemcpyAsync(dst, src, size, cudaMemcpyDeviceToDevice, stream); + cudaStreamSynchronize(stream); +}; + +// Elementwise accumulation kernel for GPU. +template +__global__ void elemwise_accum(T* out, const T* in, const size_t N) { + for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; + i < N; + i += blockDim.x * gridDim.x) { + out[i] += in[i]; + } +} + +// Synchronously accumulate tensors on the GPU, using a different stream than +// the default and than TensorFlow to avoid synchronizing on operations +// unrelated to the allreduce. +template<> void AccumulateTensorData( + float* dst, float* src, size_t size) { + auto stream = CudaStreamForMPI(); + elemwise_accum<<<32, 256, 0, stream>>>(dst, src, size); + cudaStreamSynchronize(stream); +}; +template<> void AccumulateTensorData( + int* dst, int* src, size_t size) { + auto stream = CudaStreamForMPI(); + elemwise_accum<<<32, 256, 0, stream>>>(dst, src, size); + cudaStreamSynchronize(stream); +}; + +} +} +} +#endif diff --git a/tensorflow/contrib/mpi/ring.h b/tensorflow/contrib/mpi/ring.h new file mode 100644 index 000000000..091cd7a8d --- /dev/null +++ b/tensorflow/contrib/mpi/ring.h @@ -0,0 +1,330 @@ +#ifndef TENSORFLOW_CONTRIB_MPI_H_ +#define TENSORFLOW_CONTRIB_MPI_H_ + +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/shape_inference.h" + +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" +#include "tensorflow/core/framework/tensor_types.h" + +#if GOOGLE_CUDA +#include "cuda_runtime.h" +#endif + +// Needed to avoid header issues with C++-supporting MPI implementations +#define OMPI_SKIP_MPICXX +#include "third_party/mpi/mpi.h" + +#define TAG_TENSOR 12 + +namespace tensorflow { +namespace contrib { +namespace mpi { + +using CPUDevice = Eigen::ThreadPoolDevice; +using GPUDevice = Eigen::GpuDevice; + +// Convert from templated types to values we can pass to MPI. +template +MPI_Datatype MPIType(); + +// Convert from templated types to TensorFlow data types. +template +DataType TensorFlowDataType(); + +#define MPI_REQUIRES_OK(MPI_STATUS) \ + if((MPI_STATUS) != MPI_SUCCESS) { \ + return errors::Unknown("MPI operation failed unexpectedly."); \ + } + +// Copy data from one tensor to another tensor. +// This uses a custom CUDA stream on GPU, which is necessary to overlay the +// backpropagation computations with the allreduce. +template +void CopyTensorData(void* destination, void* source, size_t size); + +// Add a tensor into another tensor, accumulating in place. +// This uses a custom CUDA stream on GPU, which is necessary to overlay the +// backpropagation computations with the allreduce. +template +void AccumulateTensorData(T* destination, T* source, size_t size); + +// We need to get the right stream for doing CUDA memory transfers and +// operations, which is possibly different from the standard TensorFlow stream. +#if GOOGLE_CUDA +cudaStream_t CudaStreamForMPI(); +#endif + +/* Perform a ring allreduce on the data. Allocate the necessary output tensor and + * store it in the output parameter. + * + * Assumes that all MPI processes are doing an allreduce of the same tensor, + * with the same dimensions. + * + * A ring allreduce is a bandwidth-optimal way to do an allreduce. To do the allreduce, + * the nodes involved are arranged in a ring: + * + * .--0--. + * / \ + * 3 1 + * \ / + * *--2--* + * + * Each node always sends to the next clockwise node in the ring, and receives + * from the previous one. + * + * The allreduce is done in two parts: a scatter-reduce and an allgather. In + * the scatter reduce, a reduction is done, so that each node ends up with a + * chunk of the final output tensor which has contributions from all other + * nodes. In the allgather, those chunks are distributed among all the nodes, + * so that all nodes have the entire output tensor. + * + * Both of these operations are done by dividing the input tensor into N + * evenly sized chunks (where N is the number of nodes in the ring). + * + * The scatter-reduce is done in N-1 steps. In the ith step, node j will send + * the (j - i)th chunk and receive the (j - i - 1)th chunk, adding it in to + * its existing data for that chunk. For example, in the first iteration with + * the ring depicted above, you will have the following transfers: + * + * Segment 0: Node 0 --> Node 1 + * Segment 1: Node 1 --> Node 2 + * Segment 2: Node 2 --> Node 3 + * Segment 3: Node 3 --> Node 0 + * + * In the second iteration, you'll have the following transfers: + * + * Segment 0: Node 1 --> Node 2 + * Segment 1: Node 2 --> Node 3 + * Segment 2: Node 3 --> Node 0 + * Segment 3: Node 0 --> Node 1 + * + * After this iteration, Node 2 has 3 of the four contributions to Segment 0. + * The last iteration has the following transfers: + * + * Segment 0: Node 2 --> Node 3 + * Segment 1: Node 3 --> Node 0 + * Segment 2: Node 0 --> Node 1 + * Segment 3: Node 1 --> Node 2 + * + * After this iteration, Node 3 has the fully accumulated Segment 0; Node 0 + * has the fully accumulated Segment 1; and so on. The scatter-reduce is complete. + * + * Next, the allgather distributes these fully accumululated chunks across all nodes. + * Communication proceeds in the same ring, once again in N-1 steps. At the ith step, + * node j will send chunk (j - i + 1) and receive chunk (j - i). For example, at the + * first iteration, the following transfers will occur: + * + * Segment 0: Node 3 --> Node 0 + * Segment 1: Node 0 --> Node 1 + * Segment 2: Node 1 --> Node 2 + * Segment 3: Node 2 --> Node 3 + * + * After the first iteration, Node 0 will have a fully accumulated Segment 0 + * (from Node 3) and Segment 1. In the next iteration, Node 0 will send its + * just-received Segment 0 onward to Node 1, and receive Segment 3 from Node 3. + * After this has continued for N - 1 iterations, all nodes will have a the fully + * accumulated tensor. + * + * Each node will do (N-1) sends for the scatter-reduce and (N-1) sends for the allgather. + * Each send will contain K / N bytes, if there are K bytes in the original tensor on every node. + * Thus, each node sends and receives 2K(N - 1)/N bytes of data, and the performance of the allreduce + * (assuming no latency in connections) is constrained by the slowest interconnect between the nodes. + * + */ +template +Status RingAllreduce(OpKernelContext* context, Tensor& input, Tensor* output) { + // Acquire MPI size and rank + int n, r; + MPI_REQUIRES_OK(MPI_Comm_size(MPI_COMM_WORLD, &n)); + MPI_REQUIRES_OK(MPI_Comm_rank(MPI_COMM_WORLD, &r)); + + // Allocate a new output tensor and copy data to it. + Status status = context->allocate_temp(TensorFlowDataType(), input.shape(), output); + if(!status.ok()) { + return status; + } + T* buffer = (T*) output->tensor_data().data(); + CopyTensorData((void*) buffer, + (void*) input.tensor_data().data(), + output->tensor_data().size()); + + // Calculate segment sizes and segment ends + const size_t elements_to_reduce = input.NumElements(); + const size_t segment_size = elements_to_reduce / n; + std::vector segment_sizes(n, segment_size); + + const size_t residual = elements_to_reduce % n; + for (size_t i = 0; i < residual; ++i) { + segment_sizes[i]++; + } + + std::vector segment_ends(n); + segment_ends[0] = segment_sizes[0]; + for (size_t i = 1; i < segment_ends.size(); ++i) { + segment_ends[i] = segment_sizes[i] + segment_ends[i-1]; + } + + assert(segment_ends[n-1] == elements_to_reduce); + + // Allocate temporary buffer - we know the first segment size is the + // largest. + tensorflow::TensorShape shape; + tensorflow::Tensor temp; + shape.AddDim(segment_sizes[0]); + status = context->allocate_temp(TensorFlowDataType(), shape, &temp); + if(!status.ok()) { + return status; + } + T* segment_recv = (T*) temp.tensor_data().data(); + + // Receive from your left neighbor with wrap-around + const size_t recv_from = ((r - 1) + n) % n; + + // Send to your right neighbor with wrap-around + const size_t send_to = (r + 1) % n; + + MPI_Status recv_status; + MPI_Request recv_req; + + // Now start ring. At every step, for every rank, we iterate through + // segments with wraparound and send and recv from our neighbors and reduce + // locally. At the i'th iteration, rank r, sends segment (r-i) and receives + // segment (r-i-1). + for (int i = 0; i < n - 1; i++) { + T* segment_send = &(buffer[segment_ends[((r-i) + n) % n] - + segment_sizes[((r-i) + n) % n]]); + + MPI_REQUIRES_OK(MPI_Irecv(segment_recv, segment_sizes[((r-i-1) + n) % n], + MPIType(), recv_from, TAG_TENSOR, MPI_COMM_WORLD, &recv_req)); + + MPI_REQUIRES_OK(MPI_Send(segment_send, segment_sizes[((r-i) + n) % n], + MPIType(), send_to, TAG_TENSOR, MPI_COMM_WORLD)); + + T *segment_update = &(buffer[segment_ends[((r-i-1) + n) % n] - + segment_sizes[((r-i-1) + n) % n]]); + + // Wait for recv to complete before reduction + MPI_Wait(&recv_req, &recv_status); + + const int N = segment_sizes[((r-i-1) + n) % n]; + auto recv = temp.Slice(0, segment_sizes[((r-i-1) + n) % n]); + AccumulateTensorData( + segment_update, (T*) recv.tensor_data().data(), N); + } + + // Now start pipelined ring allgather. At every step, for every rank, we + // iterate through segments with wraparound and send and recv from our + // neighbors. At the i'th iteration, rank r, sends segment (r+1-i) and + // receives segment (r-i). + for (size_t i = 0; i < n - 1; ++i) { + // Segment to send - at every iteration we send segment (r+1-i) + T* segment_send = &(buffer[segment_ends[((r+1-i) + n) % n] - + segment_sizes[((r+1-i) + n) % n]]); + + // Segment to recv - at every iteration we receive segment (r-i) + T* segment_recv = &(buffer[segment_ends[((r-i) + n) % n] - + segment_sizes[((r-i) + n) % n]]); + MPI_REQUIRES_OK(MPI_Sendrecv(segment_send, segment_sizes[((r+1-i) + n) % n], + MPIType(), send_to, TAG_TENSOR, segment_recv, + segment_sizes[((r-i) + n) % n], MPIType(), recv_from, + TAG_TENSOR, MPI_COMM_WORLD, &recv_status)); + } + + return Status::OK(); +} + +// Perform a ring allgather on a Tensor. Other ranks may allgather with a +// tensor which differs in the first dimension only; all other dimensions must +// be the same. +// +// For more information on the ring allgather, read the documentation for the +// ring allreduce, which includes a ring allgather. +template +Status RingAllgather(OpKernelContext* context, Tensor& input, Tensor* output, + std::vector& sizes) { + // Acquire MPI size and rank + int n, r; + MPI_REQUIRES_OK(MPI_Comm_size(MPI_COMM_WORLD, &n)); + MPI_REQUIRES_OK(MPI_Comm_rank(MPI_COMM_WORLD, &r)); + + assert(sizes.size() == n); + assert(input.dim_size(0) == sizes[r]); + + // Compute output shape: all dimensions identical, except first, which is + // the sum of all the input tensor sizes. + size_t total_dimension_size = 0; + for(auto dim : sizes) { + total_dimension_size += dim; + } + + tensorflow::TensorShape output_shape; + output_shape.AddDim(total_dimension_size); + for(int i = 1; i < input.shape().dims(); i++) { + output_shape.AddDim(input.dim_size(i)); + } + + // Compute number of elements in every "row". We can't compute number of + // elements in every chunks, because those chunks are variable length. + size_t elements_per_row = 1; + for(int i = 1; i < input.shape().dims(); i++) { + elements_per_row *= input.dim_size(i); + } + + Status status = context->allocate_temp(TensorFlowDataType(), output_shape, output); + if(!status.ok()) { + return status; + } + + // Copy data from input tensor to correct place in output tensor. + std::vector segment_starts(sizes.size()); + segment_starts[0] = 0; + for(int i = 1; i < n; i++) { + segment_starts[i] = segment_starts[i - 1] + elements_per_row * sizes[i - 1]; + } + size_t offset = segment_starts[r]; + + // Copy data to the right offset for this rank. + T* buffer = (T*) output->tensor_data().data(); + CopyTensorData((void*) (buffer + offset), + (void*) input.tensor_data().data(), + elements_per_row * sizes[r] * sizeof(T)); + + // Receive from your left neighbor with wrap-around + const size_t recv_from = ((r - 1) + n) % n; + + // Send to your right neighbor with wrap-around + const size_t send_to = (r + 1) % n; + + // Perform a ring allgather. At every step, for every rank, we iterate + // through segments with wraparound and send and recv from our neighbors. + // At the i'th iteration, rank r, sends segment (r-i) and receives segment (r-1-i). + MPI_Status recv_status; + for (size_t i = 0; i < n - 1; ++i) { + // Segment to send - at every iteration we send segment (r-i) + size_t offset_send = segment_starts[(r - i + n) % n]; + size_t rows_send = sizes[(r - i + n) % n]; + T* segment_send = &(buffer[offset_send]); + + // Segment to recv - at every iteration we receive segment (r-1-i) + size_t offset_recv = segment_starts[(r - i - 1 + n) % n]; + size_t rows_recv = sizes[(r - i - 1 + n) % n]; + T* segment_recv = &(buffer[offset_recv]); + + MPI_REQUIRES_OK(MPI_Sendrecv( + segment_send, elements_per_row * rows_send, + MPIType(), send_to, TAG_TENSOR, segment_recv, + elements_per_row * rows_recv, MPIType(), recv_from, + TAG_TENSOR, MPI_COMM_WORLD, &recv_status)); + } + + return Status::OK(); +} + +} +} +} + +#undef TENSORFLOW_CONTRIB_MPI_H_ +#endif // TENSORFLOW_CONTRIB_MPI_H_ diff --git a/tensorflow/tools/ci_build/Dockerfile.cpu.mpi b/tensorflow/tools/ci_build/Dockerfile.cpu.mpi new file mode 100644 index 000000000..d4a75a61f --- /dev/null +++ b/tensorflow/tools/ci_build/Dockerfile.cpu.mpi @@ -0,0 +1,24 @@ +FROM ubuntu:14.04 + +MAINTAINER Andrew Gibiansky + +# Copy and run the install scripts. +COPY install/*.sh /install/ +RUN /install/install_bootstrap_deb_packages.sh +RUN add-apt-repository -y ppa:openjdk-r/ppa && \ + add-apt-repository -y ppa:mc3man/trusty-media && \ + add-apt-repository -y ppa:george-edison55/cmake-3.x +RUN /install/install_deb_packages.sh +RUN /install/install_pip_packages.sh +RUN /install/install_bazel.sh +RUN /install/install_proto3.sh +RUN /install/install_buildifier.sh +RUN /install/install_mpi.sh + +# Set up bazelrc. +COPY install/.bazelrc /root/.bazelrc +ENV BAZELRC /root/.bazelrc + +# Set up MPI +ENV TF_NEED_MPI 1 +ENV MPI_PATH /usr/lib/openmpi diff --git a/tensorflow/tools/ci_build/install/install_mpi.sh b/tensorflow/tools/ci_build/install/install_mpi.sh new file mode 100755 index 000000000..6ee9d7659 --- /dev/null +++ b/tensorflow/tools/ci_build/install/install_mpi.sh @@ -0,0 +1,23 @@ +#!/usr/bin/env bash +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +set +e +mpiexec=$(which mpiexec) +if [[ -z "$mpiexec_location" ]]; then + # Install dependencies from ubuntu deb repository. + apt-get update + apt-get install -y --no-install-recommends openmpi-bin libopenmpi-dev +fi diff --git a/third_party/mpi/.gitignore b/third_party/mpi/.gitignore new file mode 100644 index 000000000..ab011617a --- /dev/null +++ b/third_party/mpi/.gitignore @@ -0,0 +1,3 @@ +*.h +*.dylib +*.so diff --git a/third_party/mpi/BUILD b/third_party/mpi/BUILD new file mode 100644 index 000000000..cbae02431 --- /dev/null +++ b/third_party/mpi/BUILD @@ -0,0 +1,26 @@ +package(default_visibility = ["//visibility:public"]) + +licenses(["notice"]) # Apache 2.0 + +exports_files(["LICENSE.txt"]) + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) + +cc_library( + name = "mpi", + srcs = select({ + "//tensorflow:darwin": ["libmpi.dylib"], + "//conditions:default": ["libmpi.so"], + }), + hdrs = ["mpi.h", "mpi_portable_platform.h"], +)