Allreduce: Rebase to TF 1.3-rc1 (#3)
* Introduce MPI allreduce in a new contrib project.

This commit adds the tensorflow.contrib.mpi namespace and contrib
project, which has a variety of ops that work with MPI.

The MPI system works by starting a background thread which communicates
between the different processes at a regular interval and schedules
asynchronous reductions. At every tick, every rank will notify rank zero
of the tensors it is ready to reduce, signifying completion with an
empty DONE message. Rank zero will count how many ranks are ready to
reduce every tensor, and, whenever a tensor is ready to reduce (that is,
every rank is ready to reduce it), rank zero will issue a message to all
other ranks directing them to reduce that tensor.  This repeats for all
the tensors that are ready to reduce, after which rank zero sends all
other ranks a DONE message indicating that the tick is complete.
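
The tick protocol can be summarized with a small, purely illustrative Python sketch; the real logic lives in the C++ background thread, and names such as `coordinator_tick` and `ready_lists` are hypothetical:

```python
from collections import defaultdict

def coordinator_tick(ready_lists, world_size):
    """Rank zero's bookkeeping for one tick.

    ready_lists maps each rank to the tensor names it reported as ready
    before sending its empty DONE message. Returns the names that every
    rank is ready to reduce, i.e. the reductions rank zero would direct
    all ranks to perform before ending the tick with its own DONE.
    """
    counts = defaultdict(int)
    for names in ready_lists.values():
        for name in names:
            counts[name] += 1
    return [name for name, count in counts.items() if count == world_size]

# One simulated tick with three ranks: only "grad/w" is ready on every rank.
print(coordinator_tick({0: ["grad/w", "grad/b"],
                        1: ["grad/w"],
                        2: ["grad/w", "grad/b"]}, world_size=3))
# -> ['grad/w']
```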

Reviewed-by: Joel Hestness <[email protected]>

* Allreduce/Allgather: Major changes and fixes (#2)

This commit makes several major updates to the TF MPI allreduce and
allgather ops. Specifically, it includes the following changes:
1) The allreduce and allgather ops had race conditions, which this commit
fixes. Specifically, the BackgroundThreadLoop previously allocated temporary
and output tensors after the main graph traversal thread had completed its
call to MPIAll*::ComputeAsync(). Unfortunately, the op kernel context's
memory allocator is only guaranteed to be valid during the ComputeAsync call.
This constraint requires ComputeAsync to allocate all tensors before
returning; otherwise, the allocator state may reflect allocations and
deallocations from subsequent ops, causing races on those memory locations.
To fix this, hoist the memory allocations into ComputeAsync. In the process,
introduce a collective op record, which tracks the parameters of the op (e.g.
input, output, and configuration); see the record sketch after this list.

2) Many models require the ability to allreduce or allgather int64 tensors. We
add handling for the long long data type (64-bit ints).

3) Eliminate the thread sleep. A major to-do item is to eliminate the need for
polling between coordinator threads and other ranks. This change will require
the coordinator rank to be able to wake up all other ranks when a collective
is ready to be performed, but also for all ranks (i.e. background threads) to
be woken up by graph traversal threads. In the meantime, remove the thread
sleep, because it introduces significant run time overhead (e.g. >20%) for
models with quick-running layers (e.g. few recurrent time-steps or few hidden
nodes per layer).
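
As a rough illustration of the collective op record from (1), here is a hedged Python analogue. The actual record is a C++ struct in mpi_ops.cc, and the field names below are hypothetical:

```python
import collections

# Everything the background thread needs is captured while ComputeAsync still
# has a valid allocator, so no tensor allocation happens after it returns.
CollectiveOpRecord = collections.namedtuple("CollectiveOpRecord", [
    "rank",           # MPI rank that owns this op instance
    "name",           # tensor name used as the coordination key across ranks
    "in_tensor",      # input buffer handed to ComputeAsync
    "out_tensor",     # output buffer, allocated up front inside ComputeAsync
    "dtype",          # element type, now including 64-bit integers
    "on_gpu",         # whether the buffers live in GPU memory
    "done_callback",  # invoked by the background thread when the reduce ends
])
```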

* mpi_ops.cc: Move toward more TF nature

This commit changes a few bits and pieces to align more closely with
TensorFlow structures and organization:

1) Use TF mutexes. TF mutexes provide nice scoping and management around
std::mutex, and using them is consistent with other TF code.

2) Remove thread sleep at MPI initialization time. Thread sleep should not
be used for polling activity. Instead, this commit replaces sleep-polling
with a condition variable: the compute graph traversal thread waits on the
condition variable until the background thread has completed initialization
and signals that initialization is complete (see the sketch after this list).

3) Slim down the MPI initialization check: Since TF permits many threads to
traverse the compute graph concurrently (e.g. with
inter_op_parallelism_threads > 1), some graph traversal threads may not
have set their GPU device ID. If such a thread executes an MPI op, it would
fail the check in InitializedMPIOnSameDevice, because the background thread
would be controlling a GPU with ID other than the default (0). Since graph
traversal threads do not perform GPU activity, this GPU ID check was
unnecessary. Remove it and refactor to just check whether MPI is
initialized (IsMPIInitialized).
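
Item (2) above replaces sleep-polling with a condition-variable handshake. Below is a minimal Python sketch of that pattern; the actual code is C++ and uses TF's mutex and condition_variable:

```python
import threading

_init_done = False
_init_cv = threading.Condition()

def background_thread():
    """Stand-in for the MPI background thread: initialize, then signal."""
    global _init_done
    # ... MPI initialization and device setup would happen here ...
    with _init_cv:
        _init_done = True
        _init_cv.notify_all()  # wake every waiting graph-traversal thread

def wait_for_mpi_init():
    """Stand-in for the graph-traversal thread's wait (no sleep-polling)."""
    with _init_cv:
        while not _init_done:  # loop guards against spurious wakeups
            _init_cv.wait()

threading.Thread(target=background_thread).start()
wait_for_mpi_init()
print("initialization signaled; graph traversal may proceed")
```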

* Rebase to TF 1.3.0-rc1 complete and tested
jthestness authored and GitHub Enterprise committed Aug 14, 2017
1 parent 6f0d70e commit 66d5b85
Showing 17 changed files with 2,803 additions and 2 deletions.
4 changes: 3 additions & 1 deletion tensorflow/contrib/BUILD
@@ -5,6 +5,8 @@ licenses(["notice"]) # Apache 2.0

package(default_visibility = ["//tensorflow:__subpackages__"])

load("//third_party/mpi:mpi.bzl", "if_mpi")

py_library(
    name = "contrib_py",
    srcs = glob(["**/*.py"]),
@@ -78,7 +80,7 @@ py_library(
"//tensorflow/contrib/tpu:tpu_py",
"//tensorflow/contrib/training:training_py",
"//tensorflow/contrib/util:util_py",
],
] + if_mpi(["//tensorflow/contrib/mpi_collectives:mpi_ops_py"]),
)

cc_library(
2 changes: 1 addition & 1 deletion tensorflow/contrib/learn/BUILD
@@ -1050,7 +1050,7 @@ py_test(
":learn",
"//tensorflow/contrib/layers:layers_py",
"//tensorflow/contrib/session_bundle:exporter",
"//tensorflow/contrib/session_bundle:manifest_proto_py_pb2",
"//tensorflow/contrib/session_bundle:manifest_proto_py",
"//tensorflow/python:array_ops",
"//tensorflow/python:client",
"//tensorflow/python:client_testlib",
72 changes: 72 additions & 0 deletions tensorflow/contrib/mpi_collectives/BUILD
@@ -0,0 +1,72 @@
# Ops that communicate with other processes via MPI.

package(default_visibility = [
    "//tensorflow:__subpackages__",
])

licenses(["notice"])  # Apache 2.0

filegroup(
    name = "all_files",
    srcs = glob(
        ["**/*"],
        exclude = [
            "**/METADATA",
            "**/OWNERS",
        ],
    ),
    visibility = ["//tensorflow:__subpackages__"],
)

load(
    "//tensorflow/core:platform/default/build_config.bzl",
    "tf_proto_library_cc",
)

tf_proto_library_cc(
    name = "mpi_message_proto",
    srcs = ["mpi_message.proto"],
    cc_api_version = 2,
    visibility = [
        "//tensorflow:__subpackages__",
    ],
)

load("//tensorflow:tensorflow.bzl", "tf_custom_op_library")
load("//tensorflow:tensorflow.bzl", "tf_py_test")

tf_custom_op_library(
    name = "mpi_collectives.so",
    srcs = ["mpi_ops.cc", "ring.cc", "ring.h"],
    gpu_srcs = ["ring.cu.cc", "ring.h"],
    deps = [
        "//third_party/mpi:mpi",
        ":mpi_message_proto_cc",
    ],
)

tf_py_test(
    name = "mpi_ops_test",
    srcs = ["mpi_ops_test.py"],
    additional_deps = [
        "//tensorflow:tensorflow_py",
        "//tensorflow/python:platform",
    ],
    data = [
        ":mpi_collectives.so",
    ],
    tags = ["manual"],
)

py_library(
    name = "mpi_ops_py",
    srcs = [
        "__init__.py",
        "mpi_ops.py",
    ],
    data = [
        ":mpi_collectives.so",
    ],
    srcs_version = "PY2AND3",
    visibility = ["//visibility:public"],
)
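
mpi_ops.py itself is not shown in this diff; the snippet below is only a guess at the usual pattern for loading a tf_custom_op_library target such as mpi_collectives.so, using APIs that exist in TF 1.3:

```python
import tensorflow as tf
from tensorflow.python.platform import resource_loader

# Hypothetical loader: find the shared library shipped as `data` above and
# register its ops with the runtime.
_mpi_lib = tf.load_op_library(
    resource_loader.get_path_to_datafile("mpi_collectives.so"))

# Wrappers generated from the library (names are guesses) would then back the
# size/rank/allgather/_allreduce functions that __init__.py imports.
```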
5 changes: 5 additions & 0 deletions tensorflow/contrib/mpi_collectives/README.md
@@ -0,0 +1,5 @@
# MPI TensorFlow integration

TensorFlow MPI integration allows communication between different TensorFlow
processes using MPI. This enables training across multiple nodes and GPUs
using high-speed interconnects.
273 changes: 273 additions & 0 deletions tensorflow/contrib/mpi_collectives/__init__.py
@@ -0,0 +1,273 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=g-short-docstring-punctuation
"""## Communicating Between Processes with MPI
TensorFlow natively provides inter-device communication through send and
receive ops and inter-node communication through Distributed TensorFlow, based
on the same send and receive abstractions. On HPC clusters where Infiniband or
other high-speed node interconnects are available, these can end up being
insufficient for synchronous data-parallel training (without asynchronous
gradient descent). This module implements a variety of MPI ops which can take
advantage of hardware-specific MPI libraries for efficient communication.
In order to use this module, TensorFlow must be built with an MPI library,
which can be provided to the `./configure` script at build time. As a user of
TensorFlow, you will need to build TensorFlow yourself to select the MPI
library to use; to do so, follow the [instructions for building TensorFlow from
source](https://www.tensorflow.org/get_started/os_setup#installing_from_sources).
### Utility Ops
In addition to reductions and gathers, this module provides utility operations
for detecting the running MPI configuration.
Example:
```python
from tensorflow.contrib import mpi
# Use `mpi.Session` instead of `tf.Session`
with mpi.Session() as session:
rank = session.run(mpi.rank())
print("My MPI Rank:", rank)
if rank == 0:
print("MPI Size:", session.run(mpi.size()))
```
@@rank
@@size
### Ring Allreduce and Allgather
When summing or averaging tensors across many processes, communication can
easily become a bottleneck. A naive implementation will send all the tensor
values to the same process, perform the reduction, and then broadcast the
values back to all other processes, effectively creating a synchronous
parameter server in one process. However, the process responsible for
performing the reduction will have to receive and send a massive amount of data
which scales with the number of processes *and* the number of parameters in the
model.
Instead of centralizing the reduction and having one primary reducer, we can
implement a distributed allreduce or allgather. A bandwidth-optimal allreduce
will end up sending 2(N - 1) values for every value in the input tensor,
and can be implemented with a ring allreduce [1]. (Intuitively, a linear reduce
requires at least (N - 1) sends between the different nodes, and a broadcast of
the result also requires (N - 1) sends, for a total of 2 (N - 1); these two
steps cannot be combined in a clever way to reduce the number of required
sends.) This module implements bandwidth-optimal ring allreduce and ring
allgather operations using MPI; by choosing a hardware-appropriate MPI
implementation (such as OpenMPI with CUDA-IPC support), you can train large
models with synchronous gradient descent with minimal communication overhead.
In addition to the `allreduce` and `allgather` functions, a convenience
`DistributedOptimizer` wrapper is provided to simplify using these functions
for reducing model gradients.
Example:
```python
import tensorflow as tf
from tensorflow.contrib import mpi_collectives as mpi
# Construct a simple linear regression model to optimize
W = tf.get_variable("W", shape=[20, 1], dtype=tf.float32)
B = tf.get_variable("B", shape=[1, 1], dtype=tf.float32)
inputs = tf.placeholder("Inputs", shape=[None, 20])
outputs = tf.placeholder("Outputs", shape=[None, 1])
loss = tf.nn.l2_loss(tf.matmul(inputs, W) + B - outputs)
# Training using MPI allreduce with DistributedOptimizer
optimizer = mpi.DistributedOptimizer(tf.train.AdamOptimizer())
train = optimizer.minimize(loss)
# Average loss over all ranks, for printing.
# Do not pass this to an optimizer!
avg_loss = mpi.allreduce(loss)
# On different ranks, feed different input data.
with mpi.Session() as session:
rank = session.run(mpi.rank())
batch_inputs, batch_outputs = construct_batch_for_rank(rank)
feed_dict = {inputs: batch_inputs, outputs: batch_outputs}
_, l = session.run([train, avg_loss], feed_dict=feed_dict)
print("Average Loss:", l)
```
[1] Patarasuk, Pitch and Yuan, Xin. "Bandwidth Optimal All-reduce Algorithms
for Clusters of Workstations".
@@Session
@@DistributedOptimizer
@@allreduce
@@allgather
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf

from tensorflow.contrib.mpi_collectives.mpi_ops import size
from tensorflow.contrib.mpi_collectives.mpi_ops import rank
from tensorflow.contrib.mpi_collectives.mpi_ops import local_rank
from tensorflow.contrib.mpi_collectives.mpi_ops import allgather
from tensorflow.contrib.mpi_collectives.mpi_ops import _allreduce
from tensorflow.contrib.mpi_collectives.mpi_ops import init


def allreduce(tensor, average=True):
  """Perform an MPI allreduce on a tf.Tensor or tf.IndexedSlices.

  Arguments:
      tensor: tf.Tensor, tf.Variable, or tf.IndexedSlices to reduce.
              The shape of the input must be identical across all ranks.
      average: If True, computes the average over all ranks.
               Otherwise, computes the sum over all ranks.

  This function performs a bandwidth-optimal ring allreduce on the input
  tensor. If the input is a tf.IndexedSlices, the function instead does an
  allgather on the values and the indices, effectively doing an allreduce on
  the represented tensor.
  """
  if isinstance(tensor, tf.IndexedSlices):
    # For IndexedSlices, do two allgathers instead of an allreduce.
    mpi_size = tf.cast(size(), tensor.values.dtype)
    values = allgather(tensor.values)
    indices = allgather(tensor.indices)

    # To make this operation into an average, divide all gathered values by
    # the MPI size.
    new_values = tf.div(values, mpi_size) if average else values
    return tf.IndexedSlices(new_values, indices,
                            dense_shape=tensor.dense_shape)
  else:
    mpi_size = tf.cast(size(), tensor.dtype)
    summed_tensor = _allreduce(tensor)
    new_tensor = (tf.div(summed_tensor, mpi_size)
                  if average else summed_tensor)
    return new_tensor


class DistributedOptimizer(tf.train.Optimizer):
  """An optimizer that wraps another tf.Optimizer, using an MPI allreduce to
  average gradient values before applying gradients to model weights."""

  def __init__(self, optimizer, name=None, use_locking=False):
    """Construct a new DistributedOptimizer, which uses another optimizer
    under the hood for computing single-process gradient values and
    applying gradient updates after the gradient values have been averaged
    across all the MPI ranks.

    Args:
      optimizer:
        Optimizer to use for computing gradients and applying updates.
      name:
        Optional name prefix for the operations created when applying
        gradients. Defaults to "Distributed" followed by the provided
        optimizer type.
      use_locking:
        Whether to use locking when updating variables.
        See Optimizer.__init__ for more info.
    """
    if name is None:
      name = "Distributed{}".format(type(optimizer).__name__)

    self._optimizer = optimizer
    super(DistributedOptimizer, self).__init__(
        name=name, use_locking=use_locking)

  def compute_gradients(self, *args, **kwargs):
    """Compute gradients of all trainable variables.

    See Optimizer.compute_gradients() for more info.

    In DistributedOptimizer, compute_gradients() is overridden to also
    allreduce the gradients before returning them.
    """
    gradients = (super(DistributedOptimizer, self)
                 .compute_gradients(*args, **kwargs))
    return [(allreduce(gradient), var) for (gradient, var) in gradients]

  def _apply_dense(self, *args, **kwargs):
    """Calls this same method on the underlying optimizer."""
    return self._optimizer._apply_dense(*args, **kwargs)

  def _apply_sparse(self, *args, **kwargs):
    """Calls this same method on the underlying optimizer."""
    return self._optimizer._apply_sparse(*args, **kwargs)

  def _apply_sparse_duplicate_indices(self, *args, **kwargs):
    """Calls this same method on the underlying optimizer."""
    return self._optimizer._apply_sparse_duplicate_indices(*args, **kwargs)

  def _prepare(self, *args, **kwargs):
    """Calls this same method on the underlying optimizer."""
    return self._optimizer._prepare(*args, **kwargs)

  def _create_slots(self, *args, **kwargs):
    """Calls this same method on the underlying optimizer."""
    return self._optimizer._create_slots(*args, **kwargs)

  def _valid_dtypes(self, *args, **kwargs):
    """Calls this same method on the underlying optimizer."""
    return self._optimizer._valid_dtypes(*args, **kwargs)

  def _finish(self, *args, **kwargs):
    """Calls this same method on the underlying optimizer."""
    return self._optimizer._finish(*args, **kwargs)


class Session(tf.Session):
  """A class for running TensorFlow operations, with copies of the same graph
  running distributed across different MPI nodes.

  The primary difference between `tf.Session` and `tf.contrib.mpi.Session` is
  that the MPI `Session` ensures that the `Session` options are correct for
  use with `tf.contrib.mpi`, and initializes MPI immediately upon the start
  of the session.
  """

  def __init__(self, target='', graph=None, config=None):
    """Creates a new TensorFlow MPI session.

    Unlike a normal `tf.Session`, an MPI Session may only use a single GPU,
    which must be specified in advance before the session is initialized.
    In addition, it only uses a single graph evaluation thread, and
    initializes MPI immediately upon starting.

    If no `graph` argument is specified when constructing the session,
    the default graph will be launched in the session. If you are
    using more than one graph (created with `tf.Graph()`) in the same
    process, you will have to use different sessions for each graph,
    but each graph can be used in multiple sessions. In this case, it
    is often clearer to pass the graph to be launched explicitly to
    the session constructor.

    Args:
      target: (Optional.) The execution engine to connect to.
      graph: (Optional.) The `Graph` to be launched (described above).
      config: (Optional.) A `ConfigProto` protocol buffer with configuration
        options for the session.
    """
    super(Session, self).__init__(target, graph, config=config)

    # Initialize MPI on the relevant device.
    # TODO: Move this to library load and eliminate mpi.Session()
    self.run(init())
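
The module docstring above claims a bandwidth-optimal allreduce moves 2(N - 1) values per input value and contrasts it with a centralized reduce. A small back-of-the-envelope sketch (not part of this commit) makes that comparison concrete:

```python
def ring_allreduce_elements_sent(n_procs, n_elems):
    """Elements sent per process, and cluster-wide, by a ring allreduce."""
    chunk = n_elems / float(n_procs)
    # scatter-reduce: N-1 steps; allgather: N-1 steps; one chunk sent per step
    per_process = 2 * (n_procs - 1) * chunk
    return per_process, per_process * n_procs

def centralized_reduce_elements(n_procs, n_elems):
    """Elements the single reducer must receive and then broadcast back."""
    return 2 * (n_procs - 1) * n_elems

per_proc, total = ring_allreduce_elements_sent(n_procs=4, n_elems=1000)
print(per_proc)                              # 1500.0 elements per process
print(total / 1000.0)                        # 6.0 = 2 * (4 - 1) per input value
print(centralized_reduce_elements(4, 1000))  # 6000, all concentrated on rank 0
```

The ring spreads the same cluster-wide volume evenly across ranks, while the centralized scheme forces one process to handle all of it.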