From 21621e5bf4920e26c1e0e83f5bacf74b19322aee Mon Sep 17 00:00:00 2001 From: Andrew Gibiansky Date: Wed, 14 Dec 2016 00:51:16 -0800 Subject: [PATCH] Introduce MPI allreduce in a new contrib project. This commit adds the tensorflow.contrib.mpi namespace and contrib project, which has a variety of ops that work with MPI. The MPI system works by starting a background thread which communicates between the different processes at a regular interval and schedules asynchronous reductions. At every tick, every rank will notify rank zero of the tensors it is ready to reduce, signifying completion with an empty DONE message. Rank zero will count how many ranks are ready to reduce every tensor, and, whenever a tensor is ready to reduce (that is, every rank is ready to reduce it), rank zero will issue a message to all other ranks directing them to reduce that tensor. This repeats for all the tensors that are ready to reduce, after which rank zero sends all other ranks a DONE message indicating that the tick is complete. --- configure | 63 + tensorflow/contrib/BUILD | 5 +- tensorflow/contrib/mpi/BUILD | 66 + tensorflow/contrib/mpi/README.md | 5 + tensorflow/contrib/mpi/__init__.py | 281 ++++ tensorflow/contrib/mpi/mpi_message.proto | 66 + tensorflow/contrib/mpi/mpi_ops.cc | 1147 +++++++++++++++++ tensorflow/contrib/mpi/mpi_ops.py | 154 +++ tensorflow/contrib/mpi/mpi_ops_test.py | 301 +++++ tensorflow/contrib/mpi/ring.cc | 48 + tensorflow/contrib/mpi/ring.cu.cc | 68 + tensorflow/contrib/mpi/ring.h | 330 +++++ tensorflow/tools/ci_build/Dockerfile.cpu.mpi | 24 + .../tools/ci_build/install/install_mpi.sh | 23 + third_party/mpi/.gitignore | 3 + third_party/mpi/BUILD | 26 + 16 files changed, 2609 insertions(+), 1 deletion(-) create mode 100644 tensorflow/contrib/mpi/BUILD create mode 100644 tensorflow/contrib/mpi/README.md create mode 100644 tensorflow/contrib/mpi/__init__.py create mode 100644 tensorflow/contrib/mpi/mpi_message.proto create mode 100644 tensorflow/contrib/mpi/mpi_ops.cc create mode 100644 tensorflow/contrib/mpi/mpi_ops.py create mode 100644 tensorflow/contrib/mpi/mpi_ops_test.py create mode 100644 tensorflow/contrib/mpi/ring.cc create mode 100644 tensorflow/contrib/mpi/ring.cu.cc create mode 100644 tensorflow/contrib/mpi/ring.h create mode 100644 tensorflow/tools/ci_build/Dockerfile.cpu.mpi create mode 100755 tensorflow/tools/ci_build/install/install_mpi.sh create mode 100644 third_party/mpi/.gitignore create mode 100644 third_party/mpi/BUILD diff --git a/configure b/configure index 87ef6e99b..22c36c9bf 100755 --- a/configure +++ b/configure @@ -63,6 +63,7 @@ if is_windows; then TF_NEED_HDFS=0 TF_NEED_JEMALLOC=0 TF_NEED_OPENCL=0 + TF_NEED_MPI=0 fi while [ "$TF_NEED_JEMALLOC" == "" ]; do @@ -112,6 +113,68 @@ else sed -i -e "s/WITH_GCP_SUPPORT = True/WITH_GCP_SUPPORT = False/" tensorflow/core/platform/default/build_config.bzl fi +while [ "$TF_NEED_MPI" == "" ]; do + read -p "Do you wish to build TensorFlow with "\ +"Message Passing Interface (MPI) support? 
[y/N] " INPUT + case $INPUT in + [Yy]* ) echo "MPI support will be enabled for "\ +"TensorFlow"; TF_NEED_MPI=1;; + [Nn]* ) echo "No MPI support will be enabled for "\ +"TensorFlow"; TF_NEED_MPI=0;; + "" ) echo "No Hadoop File System support will be enabled for "\ +"TensorFlow"; TF_NEED_MPI=0;; + * ) echo "Invalid selection: " $INPUT;; + esac +done + +while true; do + if [ "$TF_NEED_MPI" == "0" ]; then + break; + fi + + fromuser="" + if [ -z "$MPI_PATH" ]; then + default_mpi_path=$(dirname $(dirname $(which mpirun)) || dirname $(dirname $(which mpiexec)) || true) + read -p "Please specify the location of MPI. [Default is $default_mpi_path]: " MPI_PATH + fromuser="1" + if [ -z "$MPI_PATH" ]; then + MPI_PATH=$default_mpi_path + fi + fi + if [ -e "$MPI_PATH/include" -a -e "$MPI_PATH/lib" ]; then + break + fi + echo "Invalid MPI path. ${MPI_PATH}/include or ${MPI_PATH}/lib cannot be found" 1>&2 + if [ -z "$fromuser" ]; then + exit 1 + fi + MPI_PATH="" + # Retry +done + +if [ "$TF_NEED_MPI" == "1" ]; then + # Symlink necessary parts of MPI into the third party directory. + ln -sf "${MPI_PATH}/include/mpi.h" third_party/mpi/mpi.h + + if [ -e "${MPI_PATH}/include/mpi_portable_platform.h" ]; then + ln -sf "${MPI_PATH}/include/mpi_portable_platform.h" third_party/mpi/mpi_portable_platform.h + fi + + if [ -e "${MPI_PATH}/lib/libmpi.so" ]; then + ln -sf "${MPI_PATH}/lib/libmpi.so" third_party/mpi/libmpi.so + fi + + if [ -e "${MPI_PATH}/lib/libmpi.dylib" ]; then + ln -sf "${MPI_PATH}/lib/libmpi.dylib" third_party/mpi/libmpi.dylib + fi + + # Update Bazel build configuration. + sed -i -e "s/WITH_MPI_SUPPORT = False/WITH_MPI_SUPPORT = True/" tensorflow/contrib/BUILD +else + # Update Bazel build configuration. + sed -i -e "s/WITH_MPI_SUPPORT = True/WITH_MPI_SUPPORT = False/" tensorflow/contrib/BUILD +fi + while [ "$TF_NEED_HDFS" == "" ]; do read -p "Do you wish to build TensorFlow with "\ "Hadoop File System support? [y/N] " INPUT diff --git a/tensorflow/contrib/BUILD b/tensorflow/contrib/BUILD index 680053ae1..beadaf0ac 100644 --- a/tensorflow/contrib/BUILD +++ b/tensorflow/contrib/BUILD @@ -1,6 +1,9 @@ # Description: # contains parts of TensorFlow that are experimental or unstable and which are not supported. +# The configure script may change this to True. +WITH_MPI_SUPPORT = True + licenses(["notice"]) # Apache 2.0 exports_files(["LICENSE"]) @@ -56,7 +59,7 @@ py_library( "//tensorflow/contrib/tfprof", "//tensorflow/contrib/training:training_py", "//tensorflow/contrib/util:util_py", - ], + ] + ["//tensorflow/contrib/mpi:mpi_ops_py"] if WITH_MPI_SUPPORT else [], ) cc_library( diff --git a/tensorflow/contrib/mpi/BUILD b/tensorflow/contrib/mpi/BUILD new file mode 100644 index 000000000..b75773327 --- /dev/null +++ b/tensorflow/contrib/mpi/BUILD @@ -0,0 +1,66 @@ +# Ops that communicate with other processes via MPI. 
+ +licenses(["notice"]) # Apache 2.0 + +package(default_visibility = ["//tensorflow:__subpackages__"]) + +load("//tensorflow:tensorflow.bzl", "tf_gen_op_wrapper_py") +load("//tensorflow:tensorflow.bzl", "tf_copts") +load("//tensorflow:tensorflow.bzl", "tf_custom_op_library") +load("//tensorflow:tensorflow.bzl", "tf_py_test") +load("//tensorflow/core:platform/default/build_config.bzl", "tf_proto_library") + +tf_custom_op_library( + name = "mpi.so", + srcs = ["mpi_ops.cc", "ring.cc", "ring.h"], + gpu_srcs = ["ring.cu.cc", "ring.h"], + deps = [ + "//third_party/mpi:mpi", + ":mpi_message_proto_cc", + ], +) + +tf_py_test( + name = "mpi_ops_test", + srcs = ["mpi_ops_test.py"], + additional_deps = [ + "//tensorflow:tensorflow_py", + "//tensorflow/python:platform", + ], + data = [ + ":mpi.so", + ], + tags = ["manual"], +) + +py_library( + name = "mpi_ops_py", + srcs = [ + "__init__.py", + "mpi_ops.py", + ], + data = [ + ":mpi.so", + ], + srcs_version = "PY2AND3", + visibility = ["//visibility:public"], +) + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) + +tf_proto_library( + name = "mpi_message_proto", + srcs = ["mpi_message.proto"], + cc_api_version = 2, + visibility = ["//visibility:public"], +) diff --git a/tensorflow/contrib/mpi/README.md b/tensorflow/contrib/mpi/README.md new file mode 100644 index 000000000..c5e1a8c37 --- /dev/null +++ b/tensorflow/contrib/mpi/README.md @@ -0,0 +1,5 @@ +# MPI TensorFlow integration + +Tensorflow MPI integration allows communicating between different TensorFlow +processes using MPI. This enables training across multiple nodes and GPUs +using high-speed interconnects. diff --git a/tensorflow/contrib/mpi/__init__.py b/tensorflow/contrib/mpi/__init__.py new file mode 100644 index 000000000..79a562367 --- /dev/null +++ b/tensorflow/contrib/mpi/__init__.py @@ -0,0 +1,281 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# pylint: disable=g-short-docstring-punctuation +"""## Communicating Between Processes with MPI + +TensorFlow natively provides inter-device communication through send and +receive ops and inter-node communication through Distributed TensorFlow, based +on the same send and receive abstractions. On HPC clusters where Infiniband or +other high-speed node interconnects are available, these can end up being +insufficient for synchronous data-parallel training (without asynchronous +gradient descent). This module implements a variety of MPI ops which can take +advantage of hardware-specific MPI libraries for efficient communication. + +In order to use this module, TensorFlow must be built with an MPI library, +which can be provided to the `./configure` script at build time. 
As a user of +TensorFlow, you will need to build TensorFlow yourself to select the MPI +library to use; to do so, follow the [instructions for building TensorFlow from +source](https://www.tensorflow.org/get_started/os_setup#installing_from_sources). + +### Utility Ops + +In addition to reductions and gathers, this module provides utility operations +for detecting the running MPI configuration. + +Example: + +```python +from tensorflow.contrib import mpi + +# Use `mpi.Session` instead of `tf.Session` +with mpi.Session() as session: + rank = session.run(mpi.rank()) + print("My MPI Rank:", rank) + + if rank == 0: + print("MPI Size:", session.run(mpi.size())) +``` + +@@rank +@@size + +### Ring Allreduce and Allgather + +When summing or averaging tensors across many processes, communication can +easily become a bottleneck. A naive implementation will send all the tensor +values to the same process, perform the reduction, and then broadcast the +values back to all other processes, effectively creating a synchronous +parameter server in one process. However, the process responsible for +performing the reduction will have to receive and send a massive amount of data +which scales with the number of processes *and* the number of parameters in the +model. + +Instead of centralizing the reduction and having one primary reducer, we can +implement a distributed allreduce or allgather. A bandwidth-optimal allreduce +will end up sending 2(N - 1) values for every value in the input tensor, +and can be implemented with a ring allreduce [1]. (Intuitively, a linear reduce +requires at least (N - 1) sends between the different nodes, and a broadcast of +the result also requires (N - 1) sends, for a total of 2 (N - 1); these two +steps cannot be combined in a clever way to reduce the number of required +sends.) This module implements bandwidth-optimal ring allreduce and ring +allgather operations using MPI; by choosing a hardware-appropriate MPI +implementation (such as OpenMPI with CUDA-IPC support), you can train large +models with synchronous gradient descent with minimal communication overhead. + +In addition to the `allreduce` and `allgather` functions, a convenience +`DistributedOptimizer` wrapper is provided to simplify using these functions +for reducing model gradients. + +Example: + +```python +import tensorflow as tf +from tensorflow.contrib import mpi + +# Construct a simple linear regression model to optimize +W = tf.get_variable("W", shape=[20, 1], dtype=tf.float32) +B = tf.get_variable("B", shape=[1, 1], dtype=tf.float32) +inputs = tf.placeholder("Inputs", shape=[None, 20]) +outputs = tf.placeholder("Outputs", shape=[None, 1]) +loss = tf.nn.l2_loss(tf.matmul(inputs, W) + B - outputs) + +# Training using MPI allreduce with DistributedOptimizer +optimizer = mpi.DistributedOptimizer(tf.train.AdamOptimizer()) +train = optimizer.minimize(loss) + +# Average loss over all ranks, for printing. +# Do not pass this to an optimizer! +avg_loss = mpi.allreduce(loss) + +# On different ranks, feed different input data. +with mpi.Session() as session: + rank = session.run(mpi.rank()) + batch_inputs, batch_outputs = construct_batch_for_rank(rank) + feed_dict = {inputs: batch_inputs, outputs: batch_outputs} + _, l = session.run([train, avg_loss], feed_dict=feed_dict) + print("Average Loss:", l) +``` + +[1] Patarasuk, Pitch and Yuan, Xin. "Bandwidth Optimal All-reduce Algorithms +for Clusters of Workstations". 
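+
+To make the two phases concrete, here is a small, runnable, in-process
+simulation of the ring algorithm. This is a toy sketch for exposition only;
+the actual implementation lives in the `ring.h`/`ring.cc`/`ring.cu.cc`
+sources and exchanges chunks over MPI, and the helper below is not part of
+this module's API:
+
+```python
+def ring_allreduce(values):
+    # `values` holds one equal-length list per simulated rank; the return
+    # value is what every rank ends up with: the elementwise sum.
+    num_ranks, n = len(values), len(values[0])
+    starts = [c * n // num_ranks for c in range(num_ranks + 1)]
+    bufs = [list(v) for v in values]
+
+    # Phase 1 (scatter-reduce): in step s, rank r sends chunk (r - s) to its
+    # right neighbor, which accumulates it. Chunks are snapshotted first so
+    # the simulated sends within a step behave as if they were concurrent.
+    for step in range(num_ranks - 1):
+        outgoing = []
+        for r in range(num_ranks):
+            c = (r - step) % num_ranks
+            outgoing.append((r, c, bufs[r][starts[c]:starts[c + 1]]))
+        for r, c, payload in outgoing:
+            dst, lo = (r + 1) % num_ranks, starts[c]
+            for i, v in enumerate(payload):
+                bufs[dst][lo + i] += v
+
+    # Phase 2 (allgather): rank r now owns the fully reduced chunk
+    # (r + 1) % num_ranks; circulate finished chunks without reducing.
+    for step in range(num_ranks - 1):
+        for r in range(num_ranks):
+            c = (r + 1 - step) % num_ranks
+            dst = (r + 1) % num_ranks
+            bufs[dst][starts[c]:starts[c + 1]] = bufs[r][starts[c]:starts[c + 1]]
+    return bufs
+
+assert ring_allreduce([[1, 2], [3, 4]]) == [[4, 6], [4, 6]]
+```
+
+Each chunk travels (N - 1) hops around the ring in each phase, so every
+element of the input is transmitted 2(N - 1) times in total, which is where
+the bound above comes from.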
+
+@@Session
+@@DistributedOptimizer
+@@allreduce
+@@allgather
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+
+from tensorflow.contrib.mpi.mpi_ops import size
+from tensorflow.contrib.mpi.mpi_ops import rank
+from tensorflow.contrib.mpi.mpi_ops import local_rank
+from tensorflow.contrib.mpi.mpi_ops import allgather
+from tensorflow.contrib.mpi.mpi_ops import _allreduce
+from tensorflow.contrib.mpi.mpi_ops import init
+
+
+def allreduce(tensor, average=True):
+    """Perform an MPI allreduce on a tf.Tensor or tf.IndexedSlices.
+
+    Arguments:
+    tensor: tf.Tensor, tf.Variable, or tf.IndexedSlices to reduce.
+            The shape of the input must be identical across all ranks.
+    average: If True, computes the average over all ranks.
+             Otherwise, computes the sum over all ranks.
+
+    This function performs a bandwidth-optimal ring allreduce on the input
+    tensor. If the input is a tf.IndexedSlices, the function instead does an
+    allgather on the values and the indices, effectively doing an allreduce on
+    the represented tensor.
+    """
+    if isinstance(tensor, tf.IndexedSlices):
+        # For IndexedSlices, do two allgathers instead of an allreduce.
+        mpi_size = tf.cast(size(), tensor.values.dtype)
+        values = allgather(tensor.values)
+        indices = allgather(tensor.indices)
+
+        # To make this operation into an average, divide all gathered values by
+        # the MPI size.
+        new_values = tf.div(values, mpi_size) if average else values
+        return tf.IndexedSlices(new_values, indices,
+                                dense_shape=tensor.dense_shape)
+    else:
+        mpi_size = tf.cast(size(), tensor.dtype)
+        summed_tensor = _allreduce(tensor)
+        new_tensor = (tf.div(summed_tensor, mpi_size)
+                      if average else summed_tensor)
+        return new_tensor
+
+
+class DistributedOptimizer(tf.train.Optimizer):
+    """An optimizer that wraps another tf.Optimizer, using an MPI allreduce to
+    average gradient values before applying gradients to model weights."""
+
+    def __init__(self, optimizer, name=None, use_locking=False):
+        """Construct a new DistributedOptimizer, which uses another optimizer
+        under the hood for computing single-process gradient values and
+        applying gradient updates after the gradient values have been averaged
+        across all the MPI ranks.
+
+        Args:
+          optimizer:
+            Optimizer to use for computing gradients and applying updates.
+          name:
+            Optional name prefix for the operations created when applying
+            gradients. Defaults to "Distributed" followed by the provided
+            optimizer type.
+          use_locking:
+            Whether to use locking when updating variables.
+            See Optimizer.__init__ for more info.
+        """
+        if name is None:
+            name = "Distributed{}".format(type(optimizer).__name__)
+
+        self._optimizer = optimizer
+        super(DistributedOptimizer, self).__init__(
+            name=name, use_locking=use_locking)
+
+    def compute_gradients(self, *args, **kwargs):
+        """Compute gradients of all trainable variables.
+
+        See Optimizer.compute_gradients() for more info.
+
+        In DistributedOptimizer, compute_gradients() is overridden to also
+        allreduce the gradients before returning them.
+ """ + gradients = (super(DistributedOptimizer, self) + .compute_gradients(*args, **kwargs)) + return [(allreduce(gradient), var) for (gradient, var) in gradients] + + def _apply_dense(self, *args, **kwargs): + """Calls this same method on the underlying optimizer.""" + return self._optimizer._apply_dense(*args, **kwargs) + + def _apply_sparse(self, *args, **kwargs): + """Calls this same method on the underlying optimizer.""" + return self._optimizer._apply_sparse(*args, **kwargs) + + def _prepare(self, *args, **kwargs): + """Calls this same method on the underlying optimizer.""" + return self._optimizer._prepare(*args, **kwargs) + + def _create_slots(self, *args, **kwargs): + """Calls this same method on the underlying optimizer.""" + return self._optimizer._create_slots(*args, **kwargs) + + def _valid_dtypes(self, *args, **kwargs): + """Calls this same method on the underlying optimizer.""" + return self._optimizer._valid_dtypes(*args, **kwargs) + + def _finish(self, *args, **kwargs): + """Calls this same method on the underlying optimizer.""" + return self._optimizer._finish(*args, **kwargs) + + +class Session(tf.Session): + """A class for running TensorFlow operations, with copies of the same graph + running distributed across different MPI nodes. + + The primary difference between `tf.Session` and `tf.contrib.mpi.Session` is + that the MPI `Session` ensures that the `Session` options are correct for + use with `tf.contrib.mpi`, and initializes MPI immediately upon the start + of the session. + """ + + def __init__(self, gpu=None, target='', graph=None, config=None): + """Creates a new TensorFlow MPI session. + + Unlike a normal `tf.Session`, an MPI Session may only use a single GPU, + which must be specified in advance before the session is initialized. + In addition, it only uses a single graph evaluation thread, and + initializes MPI immediately upon starting. + + If no `graph` argument is specified when constructing the session, + the default graph will be launched in the session. If you are + using more than one graph (created with `tf.Graph()` in the same + process, you will have to use different sessions for each graph, + but each graph can be used in multiple sessions. In this case, it + is often clearer to pass the graph to be launched explicitly to + the session constructor. + + Args: + gpu: (Optional.) The GPU index to use, or None for CPU only MPI. + graph: (Optional.) The `Graph` to be launched (described above). + config: (Optional.) A `ConfigProto` protocol buffer with configuration + options for the session. + """ + if config is None: + config = tf.ConfigProto() + config.inter_op_parallelism_threads = 1 + elif config.inter_op_parallelism_threads != 1: + raise ValueError( + "inter_op_parallelism_threads must be 1 for MPI") + else: + config.inter_op_parallelism_threads = 1 + + if gpu is None: + config.gpu_options.visible_device_list = "" + else: + config.gpu_options.visible_device_list = str(gpu) + + super(Session, self).__init__(target, graph, config=config) + + # Initialize MPI on the relevant device. + self.run(init()) diff --git a/tensorflow/contrib/mpi/mpi_message.proto b/tensorflow/contrib/mpi/mpi_message.proto new file mode 100644 index 000000000..7c7866481 --- /dev/null +++ b/tensorflow/contrib/mpi/mpi_message.proto @@ -0,0 +1,66 @@ +syntax = "proto3"; + +package tensorflow.contrib.mpi; + +// We would like to just use DataType here, but since this +// is a contrib package, linking directly to TensorFlow protos seems to be +// impossible. 
Doing so compiles, but fails with a cryptic error at runtime +// about a pointer that was passed to free() but not created by malloc(). +// +// Since using the tensorflow/core protos seems to cause issues, we use our own, +// which also has the benefit of supporting only the data types we want to support. +enum MPIDataType { + TF_MPI_FLOAT32 = 0; + TF_MPI_INT32 = 1; +}; + +// An MPIRequest is a message sent from a rank greater than zero to the +// coordinator (rank zero), informing the coordinator of an operation that +// the rank wants to do and the tensor that it wants to apply the operation to. +message MPIRequest { + enum RequestType { + ALLREDUCE = 0; + ALLGATHER = 1; + } + + // The request rank is necessary to create a consistent ordering of results, + // for example in the allgather where the order of outputs should be sorted + // by rank. + int32 request_rank = 1; + RequestType request_type = 2; + MPIDataType tensor_type = 3; + string tensor_name = 4; + + // We use a repeated integer instead of a TensorShapeProto because linking directly + // to TensorFlow protos causes issues. See the comment for MPIDataType. + repeated int64 tensor_shape = 5; +}; + +// An MPIResponse is a message sent from the coordinator (rank zero) to a rank +// greater than zero, informing the rank of an operation should be performed +// now. If the operation requested would result in an error (for example, due +// to a type or shape mismatch), then the MPIResponse can contain an error and +// an error message instead. Finally, an MPIResponse can be a DONE message (if +// there are no more tensors to reduce on this tick of the background loop) or +// SHUTDOWN if all MPI processes should shut down. +message MPIResponse { + enum ResponseType { + ALLREDUCE = 0; + ALLGATHER = 1; + ERROR = 2; + DONE = 3; + SHUTDOWN = 4; + } + + // Empty if the type is DONE or SHUTDOWN. + ResponseType response_type = 1; + string tensor_name = 2; + + // Empty unless response_type is ERROR. + string error_message = 3; + + // Empty unless response_type is ALLGATHER. + // These tensor sizes are the dimension zero sizes of all the input matrices, + // indexed by the rank. + repeated int64 tensor_sizes = 4; +}; diff --git a/tensorflow/contrib/mpi/mpi_ops.cc b/tensorflow/contrib/mpi/mpi_ops.cc new file mode 100644 index 000000000..cc249821e --- /dev/null +++ b/tensorflow/contrib/mpi/mpi_ops.cc @@ -0,0 +1,1147 @@ +// Copyright 2016 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// =============================================================================
+
+#include <queue>
+#include <thread>
+#include <unordered_map>
+
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/shape_inference.h"
+
+#define EIGEN_USE_THREADS
+
+#if GOOGLE_CUDA
+#include "tensorflow/stream_executor/stream.h"
+#include <cuda_runtime.h>
+#endif
+
+#include "tensorflow/stream_executor/lib/statusor.h"
+
+
+#define OMPI_SKIP_MPICXX
+#include "third_party/mpi/mpi.h"
+#include "tensorflow/contrib/mpi/ring.h"
+#include "tensorflow/contrib/mpi/mpi_message.pb.h"
+
+/*
+ * MPI Allreduce and Allgather Ops for TensorFlow.
+ *
+ * TensorFlow natively provides inter-device communication through send and
+ * receive ops and inter-node communication through Distributed TensorFlow,
+ * based on the same send and receive abstractions. These end up being
+ * insufficient for synchronous data-parallel training on HPC clusters where
+ * Infiniband or other high-speed interconnects are available. This module
+ * implements MPI ops for allgather and allreduce, which do bandwidth-optimal
+ * gathers and reductions and can take advantage of hardware-optimized
+ * communication libraries through the MPI implementation.
+ *
+ * The primary logic of the allreduce and allgather is in RingAllgather() and
+ * RingAllreduce(). The background thread which facilitates MPI operations is
+ * run in BackgroundThreadLoop(). The provided MPI ops are:
+ *  - MPIInit:
+ *      Initialize MPI on a given device (CPU or GPU).
+ *      Should only be run on a single device in every process.
+ *  - MPISize:
+ *      Get the number of MPI processes in the global communicator.
+ *  - MPIRank:
+ *      Get the rank of the current MPI process in the global communicator.
+ *  - MPILocalRank:
+ *      Get the local rank of the current MPI process within its node.
+ *  - MPIAllreduce:
+ *      Perform an allreduce on a Tensor, returning the sum
+ *      across all MPI processes in the global communicator.
+ *  - MPIAllgather:
+ *      Perform an allgather on a Tensor, returning the concatenation of
+ *      the tensor on the first dimension across all MPI processes in the
+ *      global communicator.
+ *
+ */
+
+template <typename T>
+using StatusOr = perftools::gputools::port::StatusOr<T>;
+
+using CPUDevice = Eigen::ThreadPoolDevice;
+using GPUDevice = Eigen::GpuDevice;
+
+namespace tensorflow {
+namespace contrib {
+namespace mpi {
+
+// Make sure template specializations are generated in the ring.cu.cc and the
+// ring.cc file, not in this file.
+extern template Status RingAllreduce<GPUDevice, float>(
+    OpKernelContext*, Tensor&, Tensor*);
+extern template Status RingAllreduce<GPUDevice, int>(
+    OpKernelContext*, Tensor&, Tensor*);
+extern template Status RingAllgather<GPUDevice, float>(
+    OpKernelContext*, Tensor&, Tensor*, std::vector<size_t>&);
+extern template Status RingAllgather<GPUDevice, int>(
+    OpKernelContext*, Tensor&, Tensor*, std::vector<size_t>&);
+extern template Status RingAllreduce<CPUDevice, float>(
+    OpKernelContext*, Tensor&, Tensor*);
+extern template Status RingAllreduce<CPUDevice, int>(
+    OpKernelContext*, Tensor&, Tensor*);
+extern template Status RingAllgather<CPUDevice, float>(
+    OpKernelContext*, Tensor&, Tensor*, std::vector<size_t>&);
+extern template Status RingAllgather<CPUDevice, int>(
+    OpKernelContext*, Tensor&, Tensor*, std::vector<size_t>&);
+
+namespace {
+
+// Return true if the templated type is GPUDevice, otherwise false.
+template <typename Device> bool IsGPUDevice();
+template <> bool IsGPUDevice<GPUDevice>() { return true; }
+template <> bool IsGPUDevice<CPUDevice>() { return false; }
+
+// A callback to call after the MPI communication completes. Since the
+// allreduce and allgather ops are asynchronous, this callback is what resumes
+// computation after the reduction is completed.
+typedef std::function<void(StatusOr<Tensor>)> CommunicationDoneCallback;
+
+// Table storing Tensors to be reduced, keyed by unique name.
+// This table contains everything necessary to do the reduction:
+//  - Tensor: The tensor data.
+//  - OpKernelContext*: A context used to allocate the output or temporary values.
+//  - bool: Whether the tensor is resident on the GPU.
+//  - CommunicationDoneCallback: A callback to call with the result.
+typedef std::unordered_map<std::string,
+    std::tuple<Tensor, OpKernelContext*, bool, CommunicationDoneCallback> >
+    TensorTable;
+
+// Table for storing Tensor metadata on rank zero. This is used for error
+// checking and size calculations, as well as determining when a reduction is
+// ready to be done (when all nodes are ready to do it).
+typedef std::unordered_map<std::string, std::vector<MPIRequest> > MessageTable;
+
+// The global state required for the MPI ops.
+//
+// MPI is a library that stores a lot of global per-program state and often
+// requires running on a single thread. As a result, we have to have a single
+// background thread responsible for all MPI operations, and communicate with
+// that background thread through global state.
+struct MPIGlobalState {
+  // An atomic boolean which is set to true when MPI is initialized.
+  // This ensures that MPI_Init is never called twice.
+  std::atomic_flag initialized_flag = ATOMIC_FLAG_INIT;
+
+  // A mutex that needs to be used whenever MPI operations are done.
+  std::mutex mutex;
+
+  // Tensors waiting to be allreduced or allgathered.
+  TensorTable tensor_table;
+
+  // Queue of MPI requests waiting to be sent to the coordinator node.
+  std::queue<MPIRequest> message_queue;
+
+  // Background thread running MPI communication.
+  std::thread background_thread;
+
+  // Whether the background thread should shut down.
+  bool shut_down = false;
+
+  // Only exists on the coordinator node (rank zero). Maintains a count of
+  // how many nodes are ready to allreduce every tensor (keyed by tensor
+  // name).
+  std::unique_ptr<MessageTable> message_table;
+
+  // Whether MPI_Init has been completed on the background thread.
+  bool initialization_done = false;
+
+  // The device that MPI was initialized on. (-1 for no GPU)
+  int device = -1;
+
+  // Whether MPI_Init succeeded on the background thread.
+  Status init_status;
+
+  // The MPI rank, local rank, and size.
+  int rank = 0;
+  int local_rank = 0;
+  int size = 1;
+
+  // The CUDA stream used for data transfers and within-allreduce operations.
+  // A naive implementation would use the TensorFlow StreamExecutor CUDA
+  // stream. However, the allreduce and allgather require doing memory copies
+  // and kernel executions (for accumulation of values on the GPU), and the
+  // subsequent operations must wait for those operations to complete;
+  // otherwise MPI (which uses its own stream internally) will begin the data
+  // transfers before the CUDA calls are complete. In order to wait for those
+  // CUDA operations, if we were using the TensorFlow stream, we would have to
+  // synchronize that stream; however, other TensorFlow threads may be
+  // submitting more work to that stream, so synchronizing on it can cause the
+  // allreduce to be delayed, waiting for compute totally unrelated to it in
+  // other parts of the graph. Overlapping memory transfers and compute during
+  // backpropagation is crucial for good performance, so we cannot use the
+  // TensorFlow stream, and must use our own stream.
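+  // (CudaStreamForMPI() later in this file hands this stream to the ring
+  // implementation; stream_created_flag ensures it is created at most once.)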
+#if GOOGLE_CUDA + cudaStream_t stream; + std::atomic_flag stream_created_flag = ATOMIC_FLAG_INIT; +#endif + + ~MPIGlobalState() { + // Make sure that the destructor of the background thread is safe to + // call. If a thread is still joinable (not detached or complete) its + // destructor cannot be called. + if(background_thread.joinable()) { + shut_down = true; + background_thread.join(); + } + } +}; + +// All the MPI state that must be stored globally per-process. +static MPIGlobalState mpi_global; + +// For clarify in argument lists. +#define RANK_ZERO 0 + +// A tag used for all coordinator messaging. +#define TAG_NOTIFY 1 + +// Store the MPIRequest for a name, and return whether the total count of +// MPIRequests for that tensor is now equal to the MPI size (and thus we are +// ready to reduce the tensor). +bool IncrementTensorCount( + std::unique_ptr& message_table, + MPIRequest msg, int mpi_size) { + auto name = msg.tensor_name(); + auto table_iter = message_table->find(name); + if(table_iter == message_table->end()) { + message_table->emplace(name, std::vector({msg})); + table_iter = message_table->find(name); + } else { + table_iter->second.push_back(msg); + } + + int count = table_iter->second.size(); + return count == mpi_size; +} + +// Once a tensor is ready to be reduced, the coordinator sends an MPIResponse +// instructing all ranks to start the reduction to all ranks. The MPIResponse +// also contains error messages in case the submitted MPIRequests were not +// valid (for example, contained mismatched shapes or types). +// +// Constructing the MPIResponse, thus, requires a whole lot of error checking. +MPIResponse ConstructMPIResponse(std::unique_ptr& message_table, std::string name) { + bool error = false; + auto it = message_table->find(name); + assert(it != message_table->end()); + + std::vector requests = it->second; + assert(requests.size() > 0); + + std::ostringstream error_message_stream; + + // Check that all data types being reduced or gathered are identical + auto data_type = requests[0].tensor_type(); + for(unsigned int i = 1; i < requests.size(); i++) { + auto request_type = requests[i].tensor_type(); + if(data_type != request_type) { + error = true; + error_message_stream + << "Mismatched data types: One rank had type " + << MPIDataType_Name(data_type) + << ", but another rank had type " + << MPIDataType_Name(request_type) + << "."; + break; + } + } + + // Check that all requested operations are the same + auto message_type = requests[0].request_type(); + for(unsigned int i = 1; i < requests.size(); i++) { + if(error) { + break; + } + + auto request_type = requests[i].request_type(); + if(message_type != request_type) { + error = true; + error_message_stream + << "Mismatched MPI operations: One rank did an " + << message_type + << ", but another rank did an " + << request_type + << "."; + break; + } + } + + // If we are doing an allreduce, check that all tensor shapes are identical + if(message_type == MPIRequest::ALLREDUCE) { + TensorShape tensor_shape; + for(auto it = requests[0].tensor_shape().begin(); + it != requests[0].tensor_shape().end(); it++) { + tensor_shape.AddDim(*it); + } + for(unsigned int i = 1; i < requests.size(); i++) { + if(error) { + break; + } + + TensorShape request_shape; + for(auto it = requests[i].tensor_shape().begin(); + it != requests[i].tensor_shape().end(); it++) { + request_shape.AddDim(*it); + } + if(tensor_shape != request_shape) { + error = true; + error_message_stream + << "Mismatched allreduce tensor shapes: " + << "One rank 
reduced a tensor of shape " + << tensor_shape.DebugString() + << ", but another rank sent a tensor of shape " + << request_shape.DebugString() + << "."; + break; + } + } + } + + // If we are doing an allgather, make sure all but the first dimension are + // the same. The first dimension may be different and the output tensor is + // the sum of the first dimension. Collect the sizes by rank. + std::vector tensor_sizes(requests.size()); + if(message_type == MPIRequest::ALLGATHER) { + TensorShape tensor_shape; + for(auto it = requests[0].tensor_shape().begin(); + it != requests[0].tensor_shape().end(); it++) { + tensor_shape.AddDim(*it); + } + + if(tensor_shape.dims() == 0) { + error = true; + error_message_stream + << "Rank zero tried to gather a rank-zero tensor."; + } else { + tensor_sizes[requests[0].request_rank()] = size_t(tensor_shape.dim_size(0)); + } + + for(unsigned int i = 1; i < requests.size(); i++) { + if(error) { + break; + } + + TensorShape request_shape; + for(auto it = requests[i].tensor_shape().begin(); + it != requests[i].tensor_shape().end(); it++) { + request_shape.AddDim(*it); + } + if(tensor_shape.dims() != request_shape.dims()) { + error = true; + error_message_stream + << "Mismatched allgather tensor shapes: " + << "One rank gathered a tensor of rank " + << tensor_shape.dims() + << ", but another rank sent a tensor of rank " + << request_shape.dims() + << "."; + break; + } + + bool dim_mismatch = false; + for(unsigned int dim = 1; dim < tensor_shape.dims(); dim++) { + if(tensor_shape.dim_size(dim) != request_shape.dim_size(dim)) { + error = true; + error_message_stream + << "Mismatched allgather tensor shapes: " + << "One rank gathered a tensor with dimension " + << dim << " equal to " << tensor_shape.dim_size(dim) + << ", but another rank sent a tensor with dimension " + << dim << " equal to " << request_shape.dim_size(dim) + << "."; + dim_mismatch = true; + break; + } + } + if(dim_mismatch) { + break; + } + + tensor_sizes[requests[i].request_rank()] = size_t(request_shape.dim_size(0)); + } + } + + MPIResponse response; + response.set_tensor_name(name); + if(error) { + std::string error_message = error_message_stream.str(); + response.set_response_type(MPIResponse::ERROR); + response.set_error_message(error_message); + } else if(message_type == MPIRequest::ALLGATHER) { + response.set_response_type(MPIResponse::ALLGATHER); + for(auto dim : tensor_sizes) { + response.add_tensor_sizes(dim); + } + } else if(message_type == MPIRequest::ALLREDUCE) { + response.set_response_type(MPIResponse::ALLREDUCE); + } + + // Clear all queued up requests for this name. They are now taken care of + // by the constructed MPI response. + message_table->erase(it); + + return response; +} + +// Process an MPIResponse by doing a reduction, a gather, or raising an error. +void PerformReductionOrGather(TensorTable& tensor_table, MPIResponse response) { + Tensor tensor; + OpKernelContext* context; + CommunicationDoneCallback callback; + bool on_gpu; + { + // Lock on the tensor table. + std::lock_guard guard(mpi_global.mutex); + + // We should never fail at finding this key in the tensor table. 
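+    // (A response is only sent for a tensor once every rank, including this
+    // one, has enqueued a request for it, and enqueueing the request is what
+    // inserted the entry into this table.)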
+ auto name = response.tensor_name(); + auto iter = tensor_table.find(name); + assert(iter != tensor_table.end()); + + assert(response.response_type() == MPIResponse::ALLREDUCE || + response.response_type() == MPIResponse::ALLGATHER || + response.response_type() == MPIResponse::ERROR); + + std::tie(tensor, context, on_gpu, callback) = iter->second; + + // Clear the tensor table of this tensor and its callbacks; the rest of + // this function takes care of it. + tensor_table.erase(iter); + } + + // Use CPUDevice instead of GPUDevice if no CUDA, to ensure we don't + // link to non-existent symbols. +#if GOOGLE_CUDA +#define GPU_DEVICE_IF_CUDA GPUDevice +#else +#define GPU_DEVICE_IF_CUDA CPUDevice +#endif + + Tensor output; + Status status; + if(response.response_type() == MPIResponse::ALLGATHER) { + // Copy tensor sizes from the MPI response into a vector of size_t + std::vector tensor_sizes; + for(auto it = response.tensor_sizes().begin(); + it != response.tensor_sizes().end(); it++) { + tensor_sizes.push_back(size_t(*it)); + } + + if(tensor.dtype() == DT_FLOAT) { + status = on_gpu ? RingAllgather(context, tensor, &output, tensor_sizes) + : RingAllgather(context, tensor, &output, tensor_sizes); + } else if(tensor.dtype() == DT_INT32) { + status = on_gpu ? RingAllgather(context, tensor, &output, tensor_sizes) + : RingAllgather(context, tensor, &output, tensor_sizes); + } else { + status = errors::Unknown("Invalid tensor type for MPI allgather."); + } + } else if(response.response_type() == MPIResponse::ALLREDUCE) { + if(tensor.dtype() == DT_FLOAT) { + status = on_gpu ? RingAllreduce(context, tensor, &output) + : RingAllreduce(context, tensor, &output); + } else if(tensor.dtype() == DT_INT32) { + status = on_gpu ? RingAllreduce(context, tensor, &output) + : RingAllreduce(context, tensor, &output); + } else { + status = errors::Unknown("Invalid tensor type for MPI allreduce."); + } + } else if(response.response_type() == MPIResponse::ERROR) { + status = errors::FailedPrecondition(response.error_message()); + } + + if(status.ok()) { + callback(StatusOr(output)); + } else { + callback(StatusOr(status)); + } +} + +// The MPI background thread loop coordinates all the MPI processes and the +// tensor reductions. The design of the communicator mechanism is limited by a few considerations: +// +// 1. Some MPI implementations require all MPI calls to happen from a single thread. +// Since TensorFlow may use several threads for graph processing, this means we must have +// our own dedicated thread for dealing with MPI. +// 2. We want to gracefully handle errors, when MPI processes do not properly agree upon +// what should happen (such as mismatched types or shapes). To do so requires the MPI processes +// to know about the shapes and types of the relevant tensors on the other processes. +// 3. The MPI reductions and gathers should be able to happen in parallel +// with other ongoing operations. This means that they cannot be blocking +// ops, but rather must be async ops, the execution of which happens on a +// separate thread. +// 4. We cannot guarantee that all the MPI processes reduce their tensors +// in the same order, so we cannot dispatch one thread per tensor, +// otherwise we may end up dispatching many blocked threads and never make +// progress if we have a thread pool limit. +// +// The coordinator currently follows a master-worker paradigm. Rank zero acts +// as the master (the "coordinator"), whereas all other ranks are simply +// workers. 
Each rank runs its own background thread which progresses in ticks.
+// In each tick, the following actions happen:
+//
+// a) The workers send an MPIRequest to the coordinator, indicating what
+// they would like to do (which tensor they would like to gather or
+// reduce, as well as their shape and type). They repeat this for every
+// tensor that they would like to operate on.
+//
+// b) The workers send an empty "DONE" message to the coordinator to
+// indicate that there are no more tensors they wish to operate on.
+//
+// c) The coordinator receives the MPIRequests from the workers, as well
+// as from its own TensorFlow ops, and stores them in a request table. The
+// coordinator continues to receive MPIRequest messages until it has
+// received MPI_SIZE number of empty "DONE" messages.
+//
+// d) The coordinator finds all tensors that are ready to be reduced or
+// gathered, as well as all operations that would result in an error. For
+// each of those, it sends an MPIResponse to all the workers. When no more
+// MPIResponses are available, it sends a "DONE" response to the workers.
+// If the process is being shut down, it instead sends a "SHUTDOWN"
+// response.
+//
+// e) The workers listen for MPIResponse messages, processing each one by
+// doing the required reduce or gather, until they receive a "DONE"
+// response from the coordinator. At that point, the tick ends.
+// If instead of "DONE" they receive "SHUTDOWN", they exit their
+// background loop.
+void BackgroundThreadLoop(MPIGlobalState& state) {
+#if GOOGLE_CUDA
+  // Set the device, so that this thread uses the same GPU context as the
+  // calling thread. Note that device index 0 is a valid GPU, so the check
+  // must be against -1 (the CPU-only sentinel), not 0.
+  if(state.device >= 0) {
+    cudaSetDevice(state.device);
+    cudaStreamCreate(&state.stream);
+  }
+#endif
+
+  // Initialize MPI. This must happen on the background thread, since not all
+  // MPI implementations support being called from multiple threads.
+  auto init_result = MPI_Init(NULL, NULL);
+  if(init_result != MPI_SUCCESS) {
+    state.init_status = errors::Unknown("Could not initialize MPI; MPI_Init() failed.");
+  } else {
+    state.init_status = Status::OK();
+  }
+
+  // Get MPI rank to determine if we are rank zero.
+  int rank;
+  MPI_Comm_rank(MPI_COMM_WORLD, &rank);
+  bool is_coordinator = rank == 0;
+
+  // Get MPI size to determine how many tensors to wait for before reducing.
+  int size;
+  MPI_Comm_size(MPI_COMM_WORLD, &size);
+
+  // Determine local rank by querying the local communicator.
+  MPI_Comm local_comm;
+  MPI_Comm_split_type(MPI_COMM_WORLD, MPI_COMM_TYPE_SHARED, 0,
+                      MPI_INFO_NULL, &local_comm);
+  int local_rank;
+  MPI_Comm_rank(local_comm, &local_rank);
+
+  state.rank = rank;
+  state.local_rank = local_rank;
+  state.size = size;
+  state.initialization_done = true;
+
+  // Initialize the tensor count table. No tensors are available yet.
+  if(is_coordinator) {
+    state.message_table =
+        std::unique_ptr<MessageTable>(new MessageTable());
+  }
+
+  // The coordinator sends a SHUTDOWN message to trigger shutdown.
+  bool should_shut_down = false;
+  do {
+    // This delay determines thread frequency and MPI message latency
+    std::this_thread::sleep_for(std::chrono::milliseconds(5));
+
+    // Copy the data structures from global state under this lock.
+    // However, don't keep the lock for the rest of the loop, so that
+    // enqueued stream callbacks can continue.
+ std::queue message_queue; + { + std::lock_guard guard(state.mutex); + while(!state.message_queue.empty()) { + MPIRequest message = state.message_queue.front(); + state.message_queue.pop(); + message_queue.push(message); + } + } + + // Collect all tensors that are ready to be reduced. Record them in the + // tensor count table (rank zero) or send them to rank zero to be + // recorded (everyone else). + std::vector ready_to_reduce; + while(!message_queue.empty()) { + // Pop the first available message message + MPIRequest message = message_queue.front(); + message_queue.pop(); + + if(is_coordinator) { + bool reduce = IncrementTensorCount(state.message_table, message, size); + if(reduce) { + ready_to_reduce.push_back(message.tensor_name()); + } + } else { + std::string encoded_message; + message.SerializeToString(&encoded_message); + MPI_Send(encoded_message.c_str(), encoded_message.length() + 1, + MPI_BYTE, RANK_ZERO, TAG_NOTIFY, MPI_COMM_WORLD); + } + } + + // Rank zero has put all its own tensors in the tensor count table. + // Now, it should count all the tensors that are coming from other + // ranks at this tick. It should keep getting tensors until it gets a + // DONE message from all the other ranks. + if(is_coordinator) { + // Count of DONE messages. Keep receiving messages until the number + // of messages is equal to the number of processes. Initialize to + // one since the coordinator is effectively done. + int completed_ranks = 1; + while(completed_ranks != size) { + MPI_Status status; + MPI_Probe(MPI_ANY_SOURCE, TAG_NOTIFY, MPI_COMM_WORLD, &status); + + // Find number of characters in message (including zero byte). + int source_rank = status.MPI_SOURCE; + int msg_length; + MPI_Get_count(&status, MPI_BYTE, &msg_length); + + // If the length is zero, this is a DONE message. + if(msg_length == 0) { + completed_ranks++; + MPI_Recv(NULL, 0, MPI_BYTE, source_rank, TAG_NOTIFY, MPI_COMM_WORLD, &status); + continue; + } + + // Get tensor name from MPI into an std::string. + char* buffer = new char[msg_length]; + MPI_Recv(buffer, msg_length, MPI_BYTE, source_rank, + TAG_NOTIFY, MPI_COMM_WORLD, &status); + std::string received_data(buffer); + delete[] buffer; + + MPIRequest received_message; + received_message.ParseFromString(received_data); + auto received_name = received_message.tensor_name(); + + bool reduce = IncrementTensorCount( + state.message_table, received_message, size); + if(reduce) { + ready_to_reduce.push_back(received_name); + } + } + + // At this point, rank zero should have a fully updated tensor count + // table and should know all the tensors that need to be reduced or + // gathered, and everyone else should have sent all their information + // to rank zero. We can now do reductions and gathers; rank zero will + // choose which ones and in what order, and will notify the other ranks + // before doing each reduction. + for(int i = 0; i < ready_to_reduce.size(); i++) { + // Notify all nodes which tensor we'd like to reduce at this step. + auto name = ready_to_reduce[i]; + MPIResponse response = ConstructMPIResponse(state.message_table, name); + + std::string encoded_response; + response.SerializeToString(&encoded_response); + for(int r = 1; r < size; r++) { + MPI_Send(encoded_response.c_str(), encoded_response.length() + 1, + MPI_BYTE, r, TAG_NOTIFY, MPI_COMM_WORLD); + } + + // Perform the reduction. All nodes should end up performing the same reduction. 
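+        // (The coordinator runs its own share of the collective inline here;
+        // each worker runs the same collective when it receives this response
+        // in its receive loop below, so all ranks enter the MPI call.)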
+        PerformReductionOrGather(state.tensor_table, response);
+      }
+
+      // Notify all nodes that we are done with the reductions for this tick.
+      MPIResponse done_response;
+      should_shut_down = state.shut_down;
+      done_response.set_response_type(
+          should_shut_down ? MPIResponse::SHUTDOWN : MPIResponse::DONE);
+      std::string encoded_response;
+      done_response.SerializeToString(&encoded_response);
+      for(int r = 1; r < size; r++) {
+        MPI_Send(encoded_response.c_str(), encoded_response.length() + 1,
+                 MPI_BYTE, r, TAG_NOTIFY, MPI_COMM_WORLD);
+      }
+    } else {
+      // Notify the coordinator that this node is done sending messages.
+      // A DONE message is encoded as a zero-length message.
+      MPI_Send(NULL, 0, MPI_BYTE, RANK_ZERO, TAG_NOTIFY, MPI_COMM_WORLD);
+
+      // Receive names for tensors to reduce from rank zero.
+      // Once we receive an empty DONE message, stop waiting for more names.
+      while(true) {
+        MPI_Status status;
+        MPI_Probe(0, TAG_NOTIFY, MPI_COMM_WORLD, &status);
+
+        // Find number of characters in message (including zero byte).
+        int msg_length;
+        MPI_Get_count(&status, MPI_BYTE, &msg_length);
+
+        // Get tensor name from MPI into an std::string.
+        char* buffer = new char[msg_length];
+        MPI_Recv(buffer, msg_length, MPI_BYTE, 0,
+                 TAG_NOTIFY, MPI_COMM_WORLD, &status);
+        std::string received_message(buffer);
+        delete[] buffer;
+
+        MPIResponse response;
+        response.ParseFromString(received_message);
+        if(response.response_type() == MPIResponse::DONE) {
+          // No more messages this tick
+          break;
+        } else if(response.response_type() == MPIResponse::SHUTDOWN) {
+          // No more messages this tick, and the background thread should
+          // shut down
+          should_shut_down = true;
+          break;
+        } else {
+          // Process the current message
+          PerformReductionOrGather(state.tensor_table, response);
+        }
+      }
+    }
+  } while(!should_shut_down);
+
+  MPI_Finalize();
+}
+
+// Initialize MPI and start the MPI background thread. Ensure that this is
+// only done once no matter how many times this function is called.
+Status InitializeMPIOnce(bool gpu) {
+  // Ensure MPI is only initialized once.
+  if(mpi_global.initialized_flag.test_and_set())
+    return mpi_global.init_status;
+
+  int current_device = -1;
+#if GOOGLE_CUDA
+  if(gpu) {
+    cudaGetDevice(&current_device);
+  }
+#endif
+  mpi_global.device = current_device;
+
+  // Start the MPI background thread, which assumes MPI is initialized
+  mpi_global.background_thread = std::thread(BackgroundThreadLoop,
+                                             std::ref(mpi_global));
+
+  // Wait to ensure that the background thread has finished initializing MPI.
+  while(!mpi_global.initialization_done) {
+    std::this_thread::sleep_for(std::chrono::milliseconds(1));
+  }
+
+  return mpi_global.init_status;
+}
+
+// Check that MPI is initialized and that it was initialized on the same
+// device as this op is running on.
+Status InitializedMPIOnSameDevice(bool gpu) {
+  if(!mpi_global.initialization_done) {
+    return errors::FailedPrecondition("MPI has not been initialized; use tf.contrib.mpi.Session.");
+  }
+  if(gpu && mpi_global.device < 0) {
+    return errors::FailedPrecondition("MPI op on GPU, but initialized on CPU.");
+  }
+
+#if GOOGLE_CUDA
+  if(gpu) {
+    int current_device = -1;
+    cudaGetDevice(&current_device);
+    if(current_device != mpi_global.device) {
+      return errors::FailedPrecondition("MPI op on different GPU than MPI was initialized on.");
+    }
+  }
+#endif
+
+  return Status::OK();
+}
+
+// Convert a TensorFlow DataType to our MPIDataType.
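+// Only float32 and int32 are supported, mirroring the MPIDataType enum in
+// mpi_message.proto and the "T: {int32, float32}" attr on the ops below.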
+Status DataTypeToMPIType(DataType tf_dtype, MPIDataType* mpi_dtype) { + if(tf_dtype == DT_FLOAT) { + *mpi_dtype = TF_MPI_FLOAT32; + } else if(tf_dtype == DT_INT32) { + *mpi_dtype = TF_MPI_INT32; + } else { + return errors::FailedPrecondition("Invalid tensor type passed."); + } + return Status::OK(); +} + +// MPI must be initialized and the background thread must be running before +// this function is called. +void EnqueueTensorAllreduce( + OpKernelContext* context, + const Tensor& tensor, + const std::string name, + const bool on_gpu, + CommunicationDoneCallback callback) { + MPIDataType dtype; + Status status = DataTypeToMPIType(tensor.dtype(), &dtype); + if(!status.ok()) { + callback(StatusOr(status)); + return; + } + + int rank; + MPI_Comm_rank(MPI_COMM_WORLD, &rank); + + MPIRequest message; + message.set_request_rank(rank); + message.set_tensor_name(name); + message.set_tensor_type(dtype); + message.set_request_type(MPIRequest::ALLREDUCE); + for(int i = 0; i < tensor.shape().dims(); i++) { + message.add_tensor_shape(tensor.shape().dim_size(i)); + } + + std::lock_guard guard(mpi_global.mutex); + std::tuple record(tensor, context, on_gpu, callback); + mpi_global.tensor_table.emplace(name, record); + mpi_global.message_queue.push(message); +} + +// MPI must be initialized and the background thread must be running before +// this function is called. +void EnqueueTensorAllgather( + OpKernelContext* context, + const Tensor& tensor, + const std::string name, + const bool on_gpu, + CommunicationDoneCallback callback) { + MPIDataType dtype; + Status status = DataTypeToMPIType(tensor.dtype(), &dtype); + if(!status.ok()) { + callback(StatusOr(status)); + return; + } + + int rank; + MPI_Comm_rank(MPI_COMM_WORLD, &rank); + + MPIRequest message; + message.set_request_rank(rank); + message.set_tensor_name(name); + message.set_tensor_type(dtype); + message.set_request_type(MPIRequest::ALLGATHER); + for(int i = 0; i < tensor.shape().dims(); i++) { + message.add_tensor_shape(tensor.shape().dim_size(i)); + } + + std::lock_guard guard(mpi_global.mutex); + std::tuple record(tensor, context, on_gpu, callback); + mpi_global.tensor_table.emplace(name, record); + mpi_global.message_queue.push(message); +} +} + +#if GOOGLE_CUDA +cudaStream_t CudaStreamForMPI() { + return mpi_global.stream; +} +#endif + +// Op to initialize MPI in the current process. The settings used in the +// configuration are the same that must be used for all future MPI ops. +template +class MPIInitOp : public OpKernel { + public: + explicit MPIInitOp(OpKernelConstruction* context) : OpKernel(context) { + } + + + void Compute(OpKernelContext* context) override { + bool on_gpu = IsGPUDevice(); + OP_REQUIRES_OK(context, InitializeMPIOnce(on_gpu)); + } +}; + +REGISTER_KERNEL_BUILDER(Name("MPIInit").Device(DEVICE_CPU), MPIInitOp); +#if GOOGLE_CUDA +REGISTER_KERNEL_BUILDER(Name("MPIInit").Device(DEVICE_GPU), MPIInitOp); +#endif + +REGISTER_OP("MPIInit") + .Doc(R"doc( +Initialize MPI for the current process. + +If this is run on a GPU, then that GPU must be used for all future MPI +operations. If it is run on CPU, then all future MPI operations must also run on +CPU. +)doc"); + +// Op to get the current MPI Size. 
+template +class MPISizeOp : public OpKernel { + public: + explicit MPISizeOp(OpKernelConstruction* context) : OpKernel(context) { } + + + void Compute(OpKernelContext* context) override { + bool on_gpu = IsGPUDevice(); + OP_REQUIRES_OK(context, InitializedMPIOnSameDevice(on_gpu)); + + // Get the number of processes + int world_size = mpi_global.size; + + // Write integer to output tensor + Tensor* output; + OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape({}), &output)); + + auto flat = output->flat(); + flat(0) = world_size; + } +}; + +REGISTER_KERNEL_BUILDER(Name("MPISize").Device(DEVICE_CPU), MPISizeOp); +#if GOOGLE_CUDA +REGISTER_KERNEL_BUILDER(Name("MPISize").Device(DEVICE_GPU).HostMemory("size"), MPISizeOp); +#endif + +REGISTER_OP("MPISize") + .Output("size: int32") + .SetShapeFn([](shape_inference::InferenceContext* c) { + c->set_output(0, c->Scalar()); + return Status::OK(); + }) + .Doc(R"doc( +Returns the number of running MPI processes. + +More precisely, returns the number of MPI processes in the group associated +with the MPI_COMM_WORLD communicator. + +size: Size of the MPI group. +)doc"); + +// Op to get the current MPI Rank. +template +class MPIRankOp : public OpKernel { + public: + explicit MPIRankOp(OpKernelConstruction* context) : OpKernel(context) { } + + void Compute(OpKernelContext* context) override { + bool on_gpu = IsGPUDevice(); + OP_REQUIRES_OK(context, InitializedMPIOnSameDevice(on_gpu)); + + // Get the processor index + int rank = mpi_global.rank; + + // Write integer to output tensor + Tensor* output; + OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape({}), &output)); + + auto flat = output->flat(); + flat(0) = rank; + } +}; + +REGISTER_KERNEL_BUILDER(Name("MPIRank").Device(DEVICE_CPU), MPIRankOp); +#if GOOGLE_CUDA +REGISTER_KERNEL_BUILDER(Name("MPIRank").Device(DEVICE_GPU).HostMemory("rank"), MPIRankOp); +#endif + +REGISTER_OP("MPIRank") + .Output("rank: int32") + .SetShapeFn([](shape_inference::InferenceContext* c) { + c->set_output(0, c->Scalar()); + return Status::OK(); + }) + .Doc(R"doc( +Returns the index of the current process in the MPI group. + +More precisely, returns the rank of the calling process in the MPI_COMM_WORLD +communicator. + +rank: Rank of the calling process. +)doc"); + + +// Op to get the current local MPI Rank. +template +class MPILocalRankOp : public OpKernel { + public: + explicit MPILocalRankOp(OpKernelConstruction* context) : OpKernel(context) { } + + void Compute(OpKernelContext* context) override { + bool on_gpu = IsGPUDevice(); + OP_REQUIRES_OK(context, InitializedMPIOnSameDevice(on_gpu)); + + int local_rank = mpi_global.local_rank; + + // Write integer to output tensor + Tensor* output; + OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape({}), &output)); + + auto flat = output->flat(); + flat(0) = local_rank; + } +}; + +REGISTER_KERNEL_BUILDER(Name("MPILocalRank").Device(DEVICE_CPU), MPILocalRankOp); +#if GOOGLE_CUDA +REGISTER_KERNEL_BUILDER(Name("MPILocalRank").Device(DEVICE_GPU).HostMemory("rank"), MPILocalRankOp); +#endif + +REGISTER_OP("MPILocalRank") + .Output("rank: int32") + .SetShapeFn([](shape_inference::InferenceContext* c) { + c->set_output(0, c->Scalar()); + return Status::OK(); + }) + .Doc(R"doc( +Returns the index of the current process in the node it is on. + +More precisely, returns the rank of the calling process in communicator that +only spans the MPI processes running on that node. + +rank: Rank of the calling process on the node it is on. 
+)doc"); + +template +class MPIAllreduceOp : public AsyncOpKernel { + public: + explicit MPIAllreduceOp(OpKernelConstruction* context) : AsyncOpKernel(context) { } + + void ComputeAsync(OpKernelContext* context, DoneCallback done) override { + bool on_gpu = IsGPUDevice(); + OP_REQUIRES_OK(context, InitializedMPIOnSameDevice(on_gpu)); + + auto device_context = context->op_device_context(); + auto node_name = name(); + auto callback = [node_name, done, context, on_gpu] { + auto tensor = context->input(0); + EnqueueTensorAllreduce(context, tensor, node_name, on_gpu, + [node_name, done, context](StatusOr status) { + if(status.ok()) { + Tensor output = status.ValueOrDie(); + context->set_output(0, output); + } + context->SetStatus(status.status()); + done(); + }); + }; + + // If we are on a CPU, our device context will be null and we can't + // get a stream to enqueue this on. On a CPU this op is called when the + // data is already available, so we can just immediately do the allreduce; + // we don't have to wait for the data to get populated. +#if GOOGLE_CUDA + if(device_context == nullptr) { + callback(); + } else { + auto stream = device_context->stream(); + stream->ThenDoHostCallback(callback); + } +#else + callback(); +#endif + } +}; + +REGISTER_KERNEL_BUILDER(Name("MPIAllreduce").Device(DEVICE_CPU), MPIAllreduceOp); +#if GOOGLE_CUDA +REGISTER_KERNEL_BUILDER(Name("MPIAllreduce").Device(DEVICE_GPU), MPIAllreduceOp); +#endif + +REGISTER_OP("MPIAllreduce") + .Attr("T: {int32, float32}") + .Input("tensor: T") + .Output("sum: T") + .SetShapeFn([](shape_inference::InferenceContext* c) { + c->set_output(0, c->input(0)); + return Status::OK(); + }) + .Doc(R"doc( +Perform an MPI Allreduce on a tensor. All other processes that do a reduction +on a tensor with the same name must have the same dimension for that tensor. +Tensors are reduced with other tensors that have the same node name for the +allreduce. + +Arguments + tensor: A tensor to reduce. + +Output + sum: A tensor with the same shape as `tensor`, summed across all MPI processes. +)doc"); + +template +class MPIAllgatherOp : public AsyncOpKernel { + public: + explicit MPIAllgatherOp(OpKernelConstruction* context) : AsyncOpKernel(context) { + } + + void ComputeAsync(OpKernelContext* context, DoneCallback done) override { + bool on_gpu = IsGPUDevice(); + OP_REQUIRES_OK(context, InitializedMPIOnSameDevice(on_gpu)); + + auto device_context = context->op_device_context(); + auto node_name = name(); + auto callback = [node_name, done, context, on_gpu] { + auto tensor = context->input(0); + EnqueueTensorAllgather(context, tensor, node_name, on_gpu, + [node_name, done, context](StatusOr status) { + if(status.ok()) { + Tensor output = status.ValueOrDie(); + context->set_output(0, output); + } + context->SetStatus(status.status()); + done(); + }); + }; + + // If we are on a CPU, our device context will be null and we can't + // get a stream to enqueue this on. On a CPU this op is called when the + // data is already available, so we can just immediately do the allgather; + // we don't have to wait for the data to get populated. 
+
+template <typename Device>
+class MPIAllgatherOp : public AsyncOpKernel {
+ public:
+  explicit MPIAllgatherOp(OpKernelConstruction* context) : AsyncOpKernel(context) {}
+
+  void ComputeAsync(OpKernelContext* context, DoneCallback done) override {
+    bool on_gpu = IsGPUDevice<Device>();
+    OP_REQUIRES_OK(context, InitializedMPIOnSameDevice(on_gpu));
+
+    auto device_context = context->op_device_context();
+    auto node_name = name();
+    auto callback = [node_name, done, context, on_gpu] {
+      auto tensor = context->input(0);
+      EnqueueTensorAllgather(context, tensor, node_name, on_gpu,
+        [node_name, done, context](StatusOr<Tensor> status) {
+          if (status.ok()) {
+            Tensor output = status.ValueOrDie();
+            context->set_output(0, output);
+          }
+          context->SetStatus(status.status());
+          done();
+        });
+    };
+
+    // If we are on a CPU, our device context will be null and we can't
+    // get a stream to enqueue this on. On a CPU this op is called when the
+    // data is already available, so we can just immediately do the allgather;
+    // we don't have to wait for the data to get populated.
+#if GOOGLE_CUDA
+    if (device_context == nullptr) {
+      callback();
+    } else {
+      auto stream = device_context->stream();
+      stream->ThenDoHostCallback(callback);
+    }
+#else
+    callback();
+#endif
+  }
+};
+
+REGISTER_KERNEL_BUILDER(Name("MPIAllgather").Device(DEVICE_CPU), MPIAllgatherOp<CPUDevice>);
+#if GOOGLE_CUDA
+REGISTER_KERNEL_BUILDER(Name("MPIAllgather").Device(DEVICE_GPU), MPIAllgatherOp<GPUDevice>);
+#endif
+
+REGISTER_OP("MPIAllgather")
+    .Attr("T: {int32, float32}")
+    .Input("tensor: T")
+    .Output("output: T")
+    .SetShapeFn([](shape_inference::InferenceContext* c) {
+      shape_inference::ShapeHandle output;
+      TF_RETURN_IF_ERROR(c->ReplaceDim(c->input(0), 0, c->UnknownDim(), &output));
+      c->set_output(0, output);
+      return Status::OK();
+    })
+    .Doc(R"doc(
+Perform an MPI Allgather on a tensor. All other processes that do a gather on
+a tensor with the same name must have the same rank (number of dimensions)
+for that tensor, and must match on all dimensions except the first.
+
+Arguments
+    tensor: A tensor to gather.
+
+Output
+    output: A tensor with the same shape as `tensor` except for the first
+        dimension, which is the sum of the first dimensions of the gathered
+        tensors.
+)doc");
+
+}  // namespace mpi
+}  // namespace contrib
+}  // namespace tensorflow
diff --git a/tensorflow/contrib/mpi/mpi_ops.py b/tensorflow/contrib/mpi/mpi_ops.py
new file mode 100644
index 000000000..8e7734a14
--- /dev/null
+++ b/tensorflow/contrib/mpi/mpi_ops.py
@@ -0,0 +1,154 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+"""Inter-process communication using MPI."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import load_library
+from tensorflow.python.framework import ops
+from tensorflow.python.platform import resource_loader
+from tensorflow.python.platform import tf_logging as logging
+
+
+def _load_library(name, op_list=None):
+  """Loads a .so file containing the specified operators.
+
+  Args:
+    name: The name of the .so file to load.
+    op_list: A list of names of operators that the library should have. If None
+      then the .so file's contents will not be verified.
+
+  Raises:
+    NameError if one of the required ops is missing.
+  """
+  try:
+    filename = resource_loader.get_path_to_datafile(name)
+    library = load_library.load_op_library(filename)
+    for expected_op in (op_list or []):
+      for lib_op in library.OP_LIST.op:
+        if lib_op.name == expected_op:
+          break
+      else:
+        raise NameError(
+            'Could not find operator %s in dynamic library %s' %
+            (expected_op, name))
+    return library
+  except errors.NotFoundError:
+    logging.warning('%s file could not be loaded.', name)
+
+
+MPI_LIB = _load_library('mpi.so', ['MPISize', 'MPIRank', 'MPILocalRank',
+                                   'MPIAllgather', 'MPIAllreduce'])
+
+
+def size(name=None):
+  """An op which returns the number of MPI processes.
+ + This is equivalent to running `MPI_Comm_size(MPI_COMM_WORLD, ...)` to get the + size of the global communicator. + + Returns: + An integer scalar containing the number of MPI processes. + """ + return MPI_LIB.mpi_size(name=name) + + +ops.NotDifferentiable('MPISize') + + +def rank(name=None): + """An op which returns the MPI rank of the calling process. + + This is equivalent to running `MPI_Comm_rank(MPI_COMM_WORLD, ...)` to get the + rank of the current process in the global communicator. + + Returns: + An integer scalar with the MPI rank of the calling process. + """ + return MPI_LIB.mpi_rank(name=name) + + +ops.NotDifferentiable('MPIRank') + + +def init(name=None): + """An op which initializes MPI on the device on which it is run. + + All future MPI ops must be run on the same device that the `init` op was run + on. + """ + return MPI_LIB.mpi_init(name=name) + + +ops.NotDifferentiable('MPIInit') + + +def local_rank(name=None): + """An op which returns the local MPI rank of the calling process, within the + node that it is running on. For example, if there are seven processes running + on a node, their local ranks will be zero through six, inclusive. + + This is equivalent to running `MPI_Comm_rank(...)` on a new communicator + which only includes processes on the same node. + + Returns: + An integer scalar with the local MPI rank of the calling process. + """ + return MPI_LIB.mpi_local_rank(name=name) + + +ops.NotDifferentiable('MPILocalRank') + + +def _allreduce(tensor, name=None): + """An op which sums an input tensor over all the MPI processes. + + The reduction operation is keyed by the name of the op. The tensor type and + shape must be the same on all MPI processes for a given name. The reduction + will not start until all processes are ready to send and receive the tensor. + + Returns: + A tensor of the same shape and type as `tensor`, summed across all + processes. + """ + return MPI_LIB.mpi_allreduce(tensor, name=name) + + +ops.NotDifferentiable('MPIAllreduce') + + +def allgather(tensor, name=None): + """An op which concatenates the input tensor with the same input tensor on + all other MPI processes. + + The concatenation is done on the first dimension, so the input tensors on the + different processes must have the same rank and shape, except for the first + dimension, which is allowed to be different. + + Returns: + A tensor of the same type as `tensor`, concatenated on dimension zero + across all processes. The shape is identical to the input shape, except for + the first dimension, which may be greater and is the sum of all first + dimensions of the tensors in different MPI processes. + """ + return MPI_LIB.mpi_allgather(tensor, name=name) + + +ops.NotDifferentiable('MPIAllgather') + + diff --git a/tensorflow/contrib/mpi/mpi_ops_test.py b/tensorflow/contrib/mpi/mpi_ops_test.py new file mode 100644 index 000000000..84953efc2 --- /dev/null +++ b/tensorflow/contrib/mpi/mpi_ops_test.py @@ -0,0 +1,301 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+
+"""Tests for tensorflow.contrib.mpi.mpi_ops."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os.path
+import itertools
+
+import tensorflow as tf
+
+import tensorflow.contrib.mpi as mpi
+
+
+def mpi_env_rank_and_size():
+    """Get MPI rank and size from environment variables and return them as a
+    tuple of integers.
+
+    Most MPI implementations have an `mpirun` or `mpiexec` command that will
+    run an MPI executable and set up all communication necessary between the
+    different processors. As part of that set up, they will set environment
+    variables that contain the rank and size of the MPI_COMM_WORLD
+    communicator. We can read those environment variables from Python in order
+    to ensure that `mpi.rank()` and `mpi.size()` return the expected values.
+
+    Since MPI is just a standard, not an implementation, implementations
+    typically choose their own environment variable names. This function tries
+    to support several different implementations, but really it only needs to
+    support whatever implementation we want to use for the TensorFlow test
+    suite.
+
+    If this is not running under MPI, then defaults of rank zero and size one
+    are returned. (This is appropriate because when you call MPI_Init in an
+    application not started with mpirun, it will create a new independent
+    communicator with only one process in it.)
+    """
+    rank_env = "PMI_RANK OMPI_COMM_WORLD_RANK".split()
+    size_env = "PMI_SIZE OMPI_COMM_WORLD_SIZE".split()
+
+    for rank_var, size_var in zip(rank_env, size_env):
+        rank = os.environ.get(rank_var)
+        size = os.environ.get(size_var)
+        if rank is not None and size is not None:
+            return int(rank), int(size)
+
+    # Default to rank zero and size one if there are no environment variables
+    return 0, 1
+
+
+class MPITests(tf.test.TestCase):
+    """
+    Tests for MPI ops in tensorflow.contrib.mpi.
+    """
+
+    def test_mpi_rank(self):
+        """Test that the rank returned by mpi.rank() is correct."""
+        true_rank, _ = mpi_env_rank_and_size()
+        with self.test_session() as session:
+            rank = session.run(mpi.rank())
+            self.assertEqual(true_rank, rank)
+
+    def test_mpi_size(self):
+        """Test that the size returned by mpi.size() is correct."""
+        _, true_size = mpi_env_rank_and_size()
+        with self.test_session() as session:
+            size = session.run(mpi.size())
+            self.assertEqual(true_size, size)
+
+    def test_mpi_allreduce_cpu(self):
+        """Test on CPU that the allreduce correctly sums 1D, 2D, 3D tensors."""
+        with self.test_session() as session:
+            size = session.run(mpi.size())
+
+            dtypes = [tf.int32, tf.float32]
+            dims = [1, 2, 3]
+            for dtype, dim in itertools.product(dtypes, dims):
+                tf.set_random_seed(1234)
+                tensor = tf.random_uniform([17] * dim, -100, 100,
+                                           dtype=dtype)
+                summed = mpi.allreduce(tensor, average=False)
+                multiplied = tensor * size
+                max_difference = tf.reduce_max(tf.abs(summed - multiplied))
+
+                # Threshold for floating point equality depends on number of
+                # ranks, since we're comparing against precise multiplication.
+ if size <= 3: + threshold = 0 + elif size < 10: + threshold = 1e-4 + elif size < 15: + threshold = 5e-4 + else: + break + + diff = session.run(max_difference) + self.assertTrue(diff <= threshold, + "mpi.allreduce produces incorrect results") + + def test_mpi_allreduce_gpu(self): + """Test that the allreduce works on GPUs. + + This test will crash badly if used with an MPI implementation that does + not support GPU memory transfers directly, as it will call MPI_Send on + a GPU data pointer.""" + # Only do this test if there are GPUs available. + if not tf.test.is_gpu_available(cuda_only=True): + return + + no_gpus = tf.GPUOptions(visible_device_list="") + cpu_config = tf.ConfigProto(gpu_options=no_gpus) + with self.test_session(config=cpu_config) as session: + local_rank = session.run(mpi.local_rank()) + + one_gpu = tf.GPUOptions(visible_device_list=str(local_rank)) + gpu_config = tf.ConfigProto(gpu_options=one_gpu) + with self.test_session(config=gpu_config) as session: + size = session.run(mpi.size()) + + dtype = tf.float32 + dim = 3 + with tf.device("/gpu:0"): + tf.set_random_seed(1234) + tensor = tf.random_uniform([17] * dim, -100, 100, dtype=dtype) + summed = mpi.allreduce(tensor, average=False) + multiplied = tensor * size + max_difference = tf.reduce_max(tf.abs(summed - multiplied)) + + # Threshold for floating point equality depends on number of + # ranks, since we're comparing against precise multiplication. + if size <= 3: + threshold = 0 + elif size < 10: + threshold = 1e-4 + elif size < 15: + threshold = 5e-4 + else: + return + + diff = session.run(max_difference) + self.assertTrue(diff <= threshold, + "mpi.allreduce on GPU produces incorrect results") + + def test_mpi_allreduce_error(self): + """Test that the allreduce raises an error if different ranks try to + send tensors of different rank or dimension.""" + with self.test_session() as session: + rank = session.run(mpi.rank()) + size = session.run(mpi.size()) + + # This test does not apply if there is only one worker. + if size == 1: + return + + # Same rank, different dimension + tf.set_random_seed(1234) + dims = [17 + rank] * 3 + tensor = tf.random_uniform(dims, -1.0, 1.0) + with self.assertRaises(tf.errors.FailedPreconditionError): + session.run(mpi.allreduce(tensor)) + + # Same number of elements, different rank + tf.set_random_seed(1234) + if rank == 0: + dims = [17, 23 * 57] + else: + dims = [17, 23, 57] + tensor = tf.random_uniform(dims, -1.0, 1.0) + with self.assertRaises(tf.errors.FailedPreconditionError): + session.run(mpi.allreduce(tensor)) + + def test_mpi_allreduce_type_error(self): + """Test that the allreduce raises an error if different ranks try to + send tensors of different type.""" + with self.test_session() as session: + rank = session.run(mpi.rank()) + size = session.run(mpi.size()) + + # This test does not apply if there is only one worker. 
+            if size == 1:
+                return
+
+            # Same dimensions, different data type on even and odd ranks
+            dims = [17] * 3
+            tensor = tf.ones(dims,
+                             dtype=tf.int32 if rank % 2 == 0 else tf.float32)
+            with self.assertRaises(tf.errors.FailedPreconditionError):
+                session.run(mpi.allreduce(tensor))
+
+    def test_mpi_allgather(self):
+        """Test that the allgather correctly gathers 1D, 2D, 3D tensors."""
+        with self.test_session() as session:
+            size = session.run(mpi.size())
+            rank = session.run(mpi.rank())
+
+            dtypes = tf.int32, tf.float32
+            dims = 1, 2, 3
+            for dtype, dim in itertools.product(dtypes, dims):
+                tensor = tf.ones([17] * dim, dtype=dtype) * rank
+                gathered = mpi.allgather(tensor)
+
+                gathered_tensor = session.run(gathered)
+                self.assertEqual(list(gathered_tensor.shape),
+                                 [17 * size] + [17] * (dim - 1))
+
+                for i in range(size):
+                    rank_tensor = tf.slice(gathered_tensor,
+                                           [i * 17] + [0] * (dim - 1),
+                                           [17] + [-1] * (dim - 1))
+                    self.assertEqual(list(rank_tensor.shape), [17] * dim)
+                    self.assertTrue(
+                        session.run(tf.reduce_all(tf.equal(rank_tensor, i))),
+                        "mpi.allgather produces incorrect gathered tensor")
+
+    def test_mpi_allgather_variable_size(self):
+        """Test that the allgather correctly gathers 1D, 2D, 3D tensors,
+        even if those tensors have different sizes along the first dim."""
+        with self.test_session() as session:
+            size = session.run(mpi.size())
+            rank = session.run(mpi.rank())
+
+            dtypes = tf.int32, tf.float32
+            dims = 1, 2, 3
+            for dtype, dim in itertools.product(dtypes, dims):
+                # Support tests up to MPI Size of 35
+                if size > 35:
+                    break
+
+                tensor_sizes = [17, 32, 81, 12, 15, 23, 22] * 5
+                tensor_sizes = tensor_sizes[:size]
+
+                tensor = tf.ones([tensor_sizes[rank]] + [17] * (dim - 1),
+                                 dtype=dtype) * rank
+                gathered = mpi.allgather(tensor)
+
+                gathered_tensor = session.run(gathered)
+                expected_size = sum(tensor_sizes)
+                self.assertEqual(list(gathered_tensor.shape),
+                                 [expected_size] + [17] * (dim - 1))
+
+                for i in range(size):
+                    rank_size = [tensor_sizes[i]] + [17] * (dim - 1)
+                    rank_tensor = tf.slice(
+                        gathered, [sum(tensor_sizes[:i])] + [0] * (dim - 1),
+                        rank_size)
+                    self.assertEqual(list(rank_tensor.shape), rank_size)
+                    self.assertTrue(
+                        session.run(tf.reduce_all(tf.equal(rank_tensor, i))),
+                        "mpi.allgather produces incorrect gathered tensor")
+
+    def test_mpi_allgather_error(self):
+        """Test that the allgather returns an error if any dimension besides
+        the first is different among the tensors being gathered."""
+        with self.test_session() as session:
+            rank = session.run(mpi.rank())
+            size = session.run(mpi.size())
+
+            # This test does not apply if there is only one worker.
+            if size == 1:
+                return
+
+            tensor_size = [17] * 3
+            tensor_size[1] = 10 * (rank + 1)
+            tensor = tf.ones(tensor_size, dtype=tf.float32) * rank
+            with self.assertRaises(tf.errors.FailedPreconditionError):
+                session.run(mpi.allgather(tensor))
+
+    def test_mpi_allgather_type_error(self):
+        """Test that the allgather returns an error if the types being gathered
+        differ among the processes."""
+        with self.test_session() as session:
+            rank = session.run(mpi.rank())
+            size = session.run(mpi.size())
+
+            # This test does not apply if there is only one worker.
+            if size == 1:
+                return
+
+            tensor_size = [17] * 3
+            dtype = tf.int32 if rank % 2 == 0 else tf.float32
+            tensor = tf.ones(tensor_size, dtype=dtype) * rank
+            with self.assertRaises(tf.errors.FailedPreconditionError):
+                session.run(mpi.allgather(tensor))
+
+
+if __name__ == '__main__':
+    tf.test.main()
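These tests exercise multiple ranks only when every process runs the same
script under an MPI launcher; started directly, mpi_env_rank_and_size() falls
back to rank zero and size one, and the collective tests return early. With
Open MPI, for example, a four-rank run would look something like

    mpirun -np 4 python mpi_ops_test.py

where the exact launcher name and flags depend on the installed MPI
implementation.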
diff --git a/tensorflow/contrib/mpi/ring.cc b/tensorflow/contrib/mpi/ring.cc
new file mode 100644
index 000000000..2c5da128f
--- /dev/null
+++ b/tensorflow/contrib/mpi/ring.cc
@@ -0,0 +1,48 @@
+#define EIGEN_USE_THREADS
+
+#include <cstring>
+
+#include "tensorflow/contrib/mpi/ring.h"
+
+namespace tensorflow {
+namespace contrib {
+namespace mpi {
+
+using CPUDevice = Eigen::ThreadPoolDevice;
+
+extern template MPI_Datatype MPIType<float>();
+extern template MPI_Datatype MPIType<int>();
+extern template DataType TensorFlowDataType<float>();
+extern template DataType TensorFlowDataType<int>();
+
+// Generate all necessary specializations for RingAllreduce.
+template Status RingAllreduce<CPUDevice, float>(OpKernelContext*, Tensor&, Tensor*);
+template Status RingAllreduce<CPUDevice, int>(OpKernelContext*, Tensor&, Tensor*);
+
+// Generate all necessary specializations for RingAllgather.
+template Status RingAllgather<CPUDevice, float>(
+    OpKernelContext*, Tensor&, Tensor*, std::vector<size_t>&);
+template Status RingAllgather<CPUDevice, int>(
+    OpKernelContext*, Tensor&, Tensor*, std::vector<size_t>&);
+
+// Copy data on a CPU using a straightforward memcpy.
+template<> void CopyTensorData<CPUDevice>(void* dst, void* src, size_t size) {
+  std::memcpy(dst, src, size);
+}
+
+// Accumulate values on a CPU.
+template<> void AccumulateTensorData<float, CPUDevice>(
+    float* dst, float* src, size_t size) {
+  for (unsigned int i = 0; i < size; i++) {
+    dst[i] += src[i];
+  }
+}
+template<> void AccumulateTensorData<int, CPUDevice>(
+    int* dst, int* src, size_t size) {
+  for (unsigned int i = 0; i < size; i++) {
+    dst[i] += src[i];
+  }
+}
+
+}  // namespace mpi
+}  // namespace contrib
+}  // namespace tensorflow
diff --git a/tensorflow/contrib/mpi/ring.cu.cc b/tensorflow/contrib/mpi/ring.cu.cc
new file mode 100644
index 000000000..99795e7c7
--- /dev/null
+++ b/tensorflow/contrib/mpi/ring.cu.cc
@@ -0,0 +1,68 @@
+
+#if GOOGLE_CUDA
+
+#define EIGEN_USE_GPU
+
+#include "tensorflow/contrib/mpi/ring.h"
+
+namespace tensorflow {
+namespace contrib {
+namespace mpi {
+
+using GPUDevice = Eigen::GpuDevice;
+
+template<> MPI_Datatype MPIType<float>() { return MPI_FLOAT; }
+template<> MPI_Datatype MPIType<int>() { return MPI_INT; }
+
+template<> DataType TensorFlowDataType<float>() { return DT_FLOAT; }
+template<> DataType TensorFlowDataType<int>() { return DT_INT32; }
+
+// Generate all necessary specializations for RingAllreduce.
+template Status RingAllreduce<GPUDevice, float>(OpKernelContext*, Tensor&, Tensor*);
+template Status RingAllreduce<GPUDevice, int>(OpKernelContext*, Tensor&, Tensor*);
+
+// Generate all necessary specializations for RingAllgather.
+template Status RingAllgather<GPUDevice, float>(
+    OpKernelContext*, Tensor&, Tensor*, std::vector<size_t>&);
+template Status RingAllgather<GPUDevice, int>(
+    OpKernelContext*, Tensor&, Tensor*, std::vector<size_t>&);
+
+// Synchronously copy data on the GPU, using a stream other than the default
+// TensorFlow stream, to avoid synchronizing on operations unrelated to the
+// allreduce.
+template<> void CopyTensorData<GPUDevice>(void* dst, void* src, size_t size) {
+  auto stream = CudaStreamForMPI();
+  cudaMemcpyAsync(dst, src, size, cudaMemcpyDeviceToDevice, stream);
+  cudaStreamSynchronize(stream);
+}
+
+// Elementwise accumulation kernel for GPU.
+template <typename T>
+__global__ void elemwise_accum(T* out, const T* in, const size_t N) {
+  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x;
+       i < N;
+       i += blockDim.x * gridDim.x) {
+    out[i] += in[i];
+  }
+}
+
+// Synchronously accumulate tensors on the GPU, using a stream other than the
+// default TensorFlow stream, to avoid synchronizing on operations unrelated
+// to the allreduce.
+template<> void AccumulateTensorData<float, GPUDevice>(
+    float* dst, float* src, size_t size) {
+  auto stream = CudaStreamForMPI();
+  elemwise_accum<<<32, 256, 0, stream>>>(dst, src, size);
+  cudaStreamSynchronize(stream);
+}
+template<> void AccumulateTensorData<int, GPUDevice>(
+    int* dst, int* src, size_t size) {
+  auto stream = CudaStreamForMPI();
+  elemwise_accum<<<32, 256, 0, stream>>>(dst, src, size);
+  cudaStreamSynchronize(stream);
+}
+
+}  // namespace mpi
+}  // namespace contrib
+}  // namespace tensorflow
+
+#endif  // GOOGLE_CUDA
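The elemwise_accum kernel uses a grid-stride loop, so the fixed <<<32, 256>>>
launch configuration covers tensors of any length. The following pure-Python
model of that indexing is included only to illustrate the access pattern; it
is not part of the patch:

    def elemwise_accum_model(out, inp, blocks=32, threads=256):
        """Python model of the CUDA grid-stride loop: out[i] += inp[i]."""
        stride = blocks * threads  # total number of simulated CUDA threads
        for tid in range(stride):  # each iteration plays the role of one thread
            for i in range(tid, len(inp), stride):  # the grid-stride loop
                out[i] += inp[i]

    # Every element is visited exactly once, regardless of tensor length.
    dst, src = [0.0] * 10, [1.0] * 10
    elemwise_accum_model(dst, src, blocks=2, threads=3)
    assert dst == [1.0] * 10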
diff --git a/tensorflow/contrib/mpi/ring.h b/tensorflow/contrib/mpi/ring.h
new file mode 100644
index 000000000..091cd7a8d
--- /dev/null
+++ b/tensorflow/contrib/mpi/ring.h
@@ -0,0 +1,330 @@
+#ifndef TENSORFLOW_CONTRIB_MPI_H_
+#define TENSORFLOW_CONTRIB_MPI_H_
+
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/shape_inference.h"
+
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/framework/tensor_types.h"
+
+#if GOOGLE_CUDA
+#include "cuda_runtime.h"
+#endif
+
+// Needed to avoid header issues with C++-supporting MPI implementations
+#define OMPI_SKIP_MPICXX
+#include "third_party/mpi/mpi.h"
+
+#define TAG_TENSOR 12
+
+namespace tensorflow {
+namespace contrib {
+namespace mpi {
+
+using CPUDevice = Eigen::ThreadPoolDevice;
+using GPUDevice = Eigen::GpuDevice;
+
+// Convert from templated types to values we can pass to MPI.
+template <typename T>
+MPI_Datatype MPIType();
+
+// Convert from templated types to TensorFlow data types.
+template <typename T>
+DataType TensorFlowDataType();
+
+#define MPI_REQUIRES_OK(MPI_STATUS)                               \
+  if ((MPI_STATUS) != MPI_SUCCESS) {                              \
+    return errors::Unknown("MPI operation failed unexpectedly."); \
+  }
+
+// Copy data from one tensor to another tensor.
+// This uses a custom CUDA stream on GPU, which is necessary to overlap the
+// backpropagation computations with the allreduce.
+template <typename Device>
+void CopyTensorData(void* destination, void* source, size_t size);
+
+// Add a tensor into another tensor, accumulating in place.
+// This uses a custom CUDA stream on GPU, which is necessary to overlap the
+// backpropagation computations with the allreduce.
+template <typename T, typename Device>
+void AccumulateTensorData(T* destination, T* source, size_t size);
+
+// We need to get the right stream for doing CUDA memory transfers and
+// operations, which is possibly different from the standard TensorFlow stream.
+#if GOOGLE_CUDA
+cudaStream_t CudaStreamForMPI();
+#endif
+
+/* Perform a ring allreduce on the data. Allocate the necessary output tensor
+ * and store it in the output parameter.
+ *
+ * Assumes that all MPI processes are doing an allreduce of the same tensor,
+ * with the same dimensions.
+ *
+ * A ring allreduce is a bandwidth-optimal way to do an allreduce. To do the
+ * allreduce, the nodes involved are arranged in a ring:
+ *
+ *   .--0--.
+ *  /       \
+ * 3         1
+ *  \       /
+ *   *--2--*
+ *
+ * Each node always sends to the next clockwise node in the ring, and receives
+ * from the previous one.
+ *
+ * The allreduce is done in two parts: a scatter-reduce and an allgather. In
+ * the scatter-reduce, a reduction is done, so that each node ends up with a
+ * chunk of the final output tensor which has contributions from all other
+ * nodes. In the allgather, those chunks are distributed among all the nodes,
+ * so that all nodes have the entire output tensor.
+ *
+ * Both of these operations are done by dividing the input tensor into N
+ * evenly sized chunks (where N is the number of nodes in the ring).
+ *
+ * The scatter-reduce is done in N-1 steps. In the ith step, node j will send
+ * the (j - i)th chunk and receive the (j - i - 1)th chunk, adding it in to
+ * its existing data for that chunk. For example, in the first iteration with
+ * the ring depicted above, you will have the following transfers:
+ *
+ *   Segment 0: Node 0 --> Node 1
+ *   Segment 1: Node 1 --> Node 2
+ *   Segment 2: Node 2 --> Node 3
+ *   Segment 3: Node 3 --> Node 0
+ *
+ * In the second iteration, you'll have the following transfers:
+ *
+ *   Segment 0: Node 1 --> Node 2
+ *   Segment 1: Node 2 --> Node 3
+ *   Segment 2: Node 3 --> Node 0
+ *   Segment 3: Node 0 --> Node 1
+ *
+ * After this iteration, Node 2 has three of the four contributions to
+ * Segment 0. The last iteration has the following transfers:
+ *
+ *   Segment 0: Node 2 --> Node 3
+ *   Segment 1: Node 3 --> Node 0
+ *   Segment 2: Node 0 --> Node 1
+ *   Segment 3: Node 1 --> Node 2
+ *
+ * After this iteration, Node 3 has the fully accumulated Segment 0; Node 0
+ * has the fully accumulated Segment 1; and so on. The scatter-reduce is
+ * complete.
+ *
+ * Next, the allgather distributes these fully accumulated chunks across all
+ * nodes. Communication proceeds in the same ring, once again in N-1 steps. At
+ * the ith step, node j will send chunk (j - i + 1) and receive chunk (j - i).
+ * For example, at the first iteration, the following transfers will occur:
+ *
+ *   Segment 0: Node 3 --> Node 0
+ *   Segment 1: Node 0 --> Node 1
+ *   Segment 2: Node 1 --> Node 2
+ *   Segment 3: Node 2 --> Node 3
+ *
+ * After the first iteration, Node 0 will have a fully accumulated Segment 0
+ * (from Node 3) and Segment 1. In the next iteration, Node 0 will send its
+ * just-received Segment 0 onward to Node 1, and receive Segment 3 from Node 3.
+ * After this has continued for N - 1 iterations, all nodes will have the
+ * fully accumulated tensor.
+ *
+ * Each node will do (N-1) sends for the scatter-reduce and (N-1) sends for
+ * the allgather. Each send will contain K / N bytes, if there are K bytes in
+ * the original tensor on every node. Thus, each node sends and receives
+ * 2K(N - 1)/N bytes of data, and the performance of the allreduce (assuming
+ * no latency in connections) is constrained by the slowest interconnect
+ * between the nodes.
+ */
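The schedule described above is easy to sanity-check outside of C++. This
standalone Python sketch simulates both phases on plain lists, mirroring the
segment indexing used by RingAllreduce below; it illustrates the algorithm and
is not code from this patch (it assumes the element count divides evenly by
the number of ranks):

    def ring_allreduce_model(inputs):
        """Simulate a ring allreduce over `inputs`, one list per rank."""
        n = len(inputs)                     # number of ranks in the ring
        chunks = [list(x) for x in inputs]  # working buffer per rank
        size = len(inputs[0])
        assert size % n == 0, "model assumes evenly divisible tensors"
        seg = size // n

        def segment(buf, j):
            return buf[j * seg:(j + 1) * seg]

        # Scatter-reduce: at step i, rank r sends segment (r - i) and
        # accumulates the received segment (r - i - 1) into its buffer.
        for i in range(n - 1):
            sent = [segment(chunks[r], (r - i) % n) for r in range(n)]
            for r in range(n):
                j = (r - i - 1) % n
                recv = sent[(r - 1) % n]  # from the left neighbor
                for k in range(seg):
                    chunks[r][j * seg + k] += recv[k]

        # Allgather: at step i, rank r sends segment (r + 1 - i) and
        # overwrites segment (r - i) with what it receives.
        for i in range(n - 1):
            sent = [segment(chunks[r], (r + 1 - i) % n) for r in range(n)]
            for r in range(n):
                j = (r - i) % n
                chunks[r][j * seg:(j + 1) * seg] = sent[(r - 1) % n]

        return chunks

    # Four ranks, four-element tensors: every rank ends with the full sum.
    result = ring_allreduce_model([[r] * 4 for r in range(4)])
    assert all(buf == [0 + 1 + 2 + 3] * 4 for buf in result)

Running the model confirms that every rank holds the same fully reduced buffer
after 2(N - 1) steps, matching the bandwidth analysis above.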
+template <typename Device, typename T>
+Status RingAllreduce(OpKernelContext* context, Tensor& input, Tensor* output) {
+  // Acquire MPI size and rank
+  int n, r;
+  MPI_REQUIRES_OK(MPI_Comm_size(MPI_COMM_WORLD, &n));
+  MPI_REQUIRES_OK(MPI_Comm_rank(MPI_COMM_WORLD, &r));
+
+  // Allocate a new output tensor and copy data to it.
+  Status status = context->allocate_temp(TensorFlowDataType<T>(), input.shape(), output);
+  if (!status.ok()) {
+    return status;
+  }
+  T* buffer = (T*) output->tensor_data().data();
+  CopyTensorData<Device>((void*) buffer,
+                         (void*) input.tensor_data().data(),
+                         output->tensor_data().size());
+
+  // Calculate segment sizes and segment ends
+  const size_t elements_to_reduce = input.NumElements();
+  const size_t segment_size = elements_to_reduce / n;
+  std::vector<size_t> segment_sizes(n, segment_size);
+
+  const size_t residual = elements_to_reduce % n;
+  for (size_t i = 0; i < residual; ++i) {
+    segment_sizes[i]++;
+  }
+
+  std::vector<size_t> segment_ends(n);
+  segment_ends[0] = segment_sizes[0];
+  for (size_t i = 1; i < segment_ends.size(); ++i) {
+    segment_ends[i] = segment_sizes[i] + segment_ends[i - 1];
+  }
+
+  assert(segment_ends[n - 1] == elements_to_reduce);
+
+  // Allocate a temporary buffer - we know the first segment size is the
+  // largest.
+  tensorflow::TensorShape shape;
+  tensorflow::Tensor temp;
+  shape.AddDim(segment_sizes[0]);
+  status = context->allocate_temp(TensorFlowDataType<T>(), shape, &temp);
+  if (!status.ok()) {
+    return status;
+  }
+  T* segment_recv = (T*) temp.tensor_data().data();
+
+  // Receive from your left neighbor with wrap-around
+  const size_t recv_from = ((r - 1) + n) % n;
+
+  // Send to your right neighbor with wrap-around
+  const size_t send_to = (r + 1) % n;
+
+  MPI_Status recv_status;
+  MPI_Request recv_req;
+
+  // Now start the ring. At every step, for every rank, we iterate through
+  // segments with wraparound and send and recv from our neighbors and reduce
+  // locally. At the i'th iteration, rank r sends segment (r-i) and receives
+  // segment (r-i-1).
+  for (int i = 0; i < n - 1; i++) {
+    T* segment_send = &(buffer[segment_ends[((r-i) + n) % n] -
+                               segment_sizes[((r-i) + n) % n]]);
+
+    MPI_REQUIRES_OK(MPI_Irecv(segment_recv, segment_sizes[((r-i-1) + n) % n],
+                              MPIType<T>(), recv_from, TAG_TENSOR, MPI_COMM_WORLD, &recv_req));
+
+    MPI_REQUIRES_OK(MPI_Send(segment_send, segment_sizes[((r-i) + n) % n],
+                             MPIType<T>(), send_to, TAG_TENSOR, MPI_COMM_WORLD));
+
+    T* segment_update = &(buffer[segment_ends[((r-i-1) + n) % n] -
+                                 segment_sizes[((r-i-1) + n) % n]]);
+
+    // Wait for recv to complete before reduction
+    MPI_Wait(&recv_req, &recv_status);
+
+    const int N = segment_sizes[((r-i-1) + n) % n];
+    auto recv = temp.Slice(0, segment_sizes[((r-i-1) + n) % n]);
+    AccumulateTensorData<T, Device>(
+        segment_update, (T*) recv.tensor_data().data(), N);
+  }
+
+  // Now start the pipelined ring allgather. At every step, for every rank, we
+  // iterate through segments with wraparound and send and recv from our
+  // neighbors. At the i'th iteration, rank r sends segment (r+1-i) and
+  // receives segment (r-i).
+  for (size_t i = 0; i < n - 1; ++i) {
+    // Segment to send - at every iteration we send segment (r+1-i)
+    T* segment_send = &(buffer[segment_ends[((r+1-i) + n) % n] -
+                               segment_sizes[((r+1-i) + n) % n]]);
+
+    // Segment to recv - at every iteration we receive segment (r-i)
+    T* segment_recv = &(buffer[segment_ends[((r-i) + n) % n] -
+                               segment_sizes[((r-i) + n) % n]]);
+    MPI_REQUIRES_OK(MPI_Sendrecv(segment_send, segment_sizes[((r+1-i) + n) % n],
+                                 MPIType<T>(), send_to, TAG_TENSOR, segment_recv,
+                                 segment_sizes[((r-i) + n) % n], MPIType<T>(), recv_from,
+                                 TAG_TENSOR, MPI_COMM_WORLD, &recv_status));
+  }
+
+  return Status::OK();
+}
+
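When the element count does not divide evenly by the number of ranks, the code
above gives each of the first `residual` segments one extra element, and
segment_ends accumulates the prefix sums. A small Python check of that
bookkeeping, illustrative only:

    def segment_layout(num_elements, n):
        """Mirror the segment_sizes/segment_ends computation above."""
        sizes = [num_elements // n] * n
        for i in range(num_elements % n):  # spread the remainder
            sizes[i] += 1
        ends, total = [], 0
        for s in sizes:
            total += s
            ends.append(total)
        return sizes, ends

    sizes, ends = segment_layout(10, 4)
    assert sizes == [3, 3, 2, 2] and ends == [3, 6, 8, 10]
    assert ends[-1] == 10  # matches the assert in the C++ code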
+// Perform a ring allgather on a Tensor. Other ranks may allgather with a
+// tensor which differs in the first dimension only; all other dimensions must
+// be the same.
+//
+// For more information on the ring allgather, read the documentation for the
+// ring allreduce, which includes a ring allgather.
+template <typename Device, typename T>
+Status RingAllgather(OpKernelContext* context, Tensor& input, Tensor* output,
+                     std::vector<size_t>& sizes) {
+  // Acquire MPI size and rank
+  int n, r;
+  MPI_REQUIRES_OK(MPI_Comm_size(MPI_COMM_WORLD, &n));
+  MPI_REQUIRES_OK(MPI_Comm_rank(MPI_COMM_WORLD, &r));
+
+  assert(sizes.size() == n);
+  assert(input.dim_size(0) == sizes[r]);
+
+  // Compute the output shape: all dimensions are identical, except the first,
+  // which is the sum of all the input tensor sizes.
+  size_t total_dimension_size = 0;
+  for (auto dim : sizes) {
+    total_dimension_size += dim;
+  }
+
+  tensorflow::TensorShape output_shape;
+  output_shape.AddDim(total_dimension_size);
+  for (int i = 1; i < input.shape().dims(); i++) {
+    output_shape.AddDim(input.dim_size(i));
+  }
+
+  // Compute the number of elements in every "row". We can't compute the
+  // number of elements in every chunk, because the chunks have variable
+  // lengths.
+  size_t elements_per_row = 1;
+  for (int i = 1; i < input.shape().dims(); i++) {
+    elements_per_row *= input.dim_size(i);
+  }
+
+  Status status = context->allocate_temp(TensorFlowDataType<T>(), output_shape, output);
+  if (!status.ok()) {
+    return status;
+  }
+
+  // Copy data from the input tensor to the correct place in the output tensor.
+  std::vector<size_t> segment_starts(sizes.size());
+  segment_starts[0] = 0;
+  for (int i = 1; i < n; i++) {
+    segment_starts[i] = segment_starts[i - 1] + elements_per_row * sizes[i - 1];
+  }
+  size_t offset = segment_starts[r];
+
+  // Copy data to the right offset for this rank.
+  T* buffer = (T*) output->tensor_data().data();
+  CopyTensorData<Device>((void*) (buffer + offset),
+                         (void*) input.tensor_data().data(),
+                         elements_per_row * sizes[r] * sizeof(T));
+
+  // Receive from your left neighbor with wrap-around
+  const size_t recv_from = ((r - 1) + n) % n;
+
+  // Send to your right neighbor with wrap-around
+  const size_t send_to = (r + 1) % n;
+
+  // Perform a ring allgather. At every step, for every rank, we iterate
+  // through segments with wraparound and send and recv from our neighbors.
+  // At the i'th iteration, rank r sends segment (r-i) and receives segment
+  // (r-1-i).
+  MPI_Status recv_status;
+  for (size_t i = 0; i < n - 1; ++i) {
+    // Segment to send - at every iteration we send segment (r-i)
+    size_t offset_send = segment_starts[(r - i + n) % n];
+    size_t rows_send = sizes[(r - i + n) % n];
+    T* segment_send = &(buffer[offset_send]);
+
+    // Segment to recv - at every iteration we receive segment (r-1-i)
+    size_t offset_recv = segment_starts[(r - i - 1 + n) % n];
+    size_t rows_recv = sizes[(r - i - 1 + n) % n];
+    T* segment_recv = &(buffer[offset_recv]);
+
+    MPI_REQUIRES_OK(MPI_Sendrecv(
+        segment_send, elements_per_row * rows_send,
+        MPIType<T>(), send_to, TAG_TENSOR, segment_recv,
+        elements_per_row * rows_recv, MPIType<T>(), recv_from,
+        TAG_TENSOR, MPI_COMM_WORLD, &recv_status));
+  }
+
+  return Status::OK();
+}
+
+}  // namespace mpi
+}  // namespace contrib
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_CONTRIB_MPI_H_
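Because the gathered chunks may differ in their first dimension, RingAllgather
tracks per-rank element offsets (segment_starts) instead of assuming equal
segments. A quick Python illustration of that layout computation, not part of
the patch:

    def allgather_offsets(sizes, elements_per_row):
        """Mirror the segment_starts computation in RingAllgather."""
        starts = [0]
        for rows in sizes[:-1]:
            starts.append(starts[-1] + rows * elements_per_row)
        return starts

    # Three ranks contributing 2, 4, and 3 rows of 5 elements each:
    assert allgather_offsets([2, 4, 3], 5) == [0, 10, 30]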
+COPY install/*.sh /install/
+RUN /install/install_bootstrap_deb_packages.sh
+RUN add-apt-repository -y ppa:openjdk-r/ppa && \
+    add-apt-repository -y ppa:mc3man/trusty-media && \
+    add-apt-repository -y ppa:george-edison55/cmake-3.x
+RUN /install/install_deb_packages.sh
+RUN /install/install_pip_packages.sh
+RUN /install/install_bazel.sh
+RUN /install/install_proto3.sh
+RUN /install/install_buildifier.sh
+RUN /install/install_mpi.sh
+
+# Set up bazelrc.
+COPY install/.bazelrc /root/.bazelrc
+ENV BAZELRC /root/.bazelrc
+
+# Set up MPI
+ENV TF_NEED_MPI 1
+ENV MPI_PATH /usr/lib/openmpi
diff --git a/tensorflow/tools/ci_build/install/install_mpi.sh b/tensorflow/tools/ci_build/install/install_mpi.sh
new file mode 100755
index 000000000..6ee9d7659
--- /dev/null
+++ b/tensorflow/tools/ci_build/install/install_mpi.sh
@@ -0,0 +1,23 @@
+#!/usr/bin/env bash
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+set +e
+mpiexec_location=$(which mpiexec)
+if [[ -z "$mpiexec_location" ]]; then
+  # Install dependencies from the Ubuntu deb repository.
+  apt-get update
+  apt-get install -y --no-install-recommends openmpi-bin libopenmpi-dev
+fi
diff --git a/third_party/mpi/.gitignore b/third_party/mpi/.gitignore
new file mode 100644
index 000000000..ab011617a
--- /dev/null
+++ b/third_party/mpi/.gitignore
@@ -0,0 +1,3 @@
+*.h
+*.dylib
+*.so
diff --git a/third_party/mpi/BUILD b/third_party/mpi/BUILD
new file mode 100644
index 000000000..cbae02431
--- /dev/null
+++ b/third_party/mpi/BUILD
@@ -0,0 +1,26 @@
+package(default_visibility = ["//visibility:public"])
+
+licenses(["notice"])  # Apache 2.0
+
+exports_files(["LICENSE.txt"])
+
+filegroup(
+    name = "all_files",
+    srcs = glob(
+        ["**/*"],
+        exclude = [
+            "**/METADATA",
+            "**/OWNERS",
+        ],
+    ),
+    visibility = ["//tensorflow:__subpackages__"],
+)
+
+cc_library(
+    name = "mpi",
+    srcs = select({
+        "//tensorflow:darwin": ["libmpi.dylib"],
+        "//conditions:default": ["libmpi.so"],
+    }),
+    hdrs = [
+        "mpi.h",
+        "mpi_portable_platform.h",
+    ],
+)
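With MPI support configured (TF_NEED_MPI=1) and the library built, the
intended end use of this contrib package is data-parallel training: each rank
computes gradients on its own data shard and the ranks average them with the
allreduce before applying updates. A hedged end-to-end sketch; the loss and
optimizer here are placeholders, and it assumes the public mpi.allreduce
wrapper averages by default, as sketched earlier:

    import tensorflow as tf
    import tensorflow.contrib.mpi as mpi

    # Toy model: any scalar loss and TF 1.x optimizer would do here.
    x = tf.Variable(1.0)
    loss = tf.square(x - 3.0)
    opt = tf.train.GradientDescentOptimizer(0.1)

    grads_and_vars = opt.compute_gradients(loss)
    # Average each gradient across ranks so all replicas apply the same update.
    averaged = [(mpi.allreduce(g), v)
                for g, v in grads_and_vars if g is not None]
    train_op = opt.apply_gradients(averaged)

    with tf.Session() as session:
        session.run(tf.global_variables_initializer())
        for _ in range(10):
            session.run(train_op)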