Initial commit
Tomasz Grel committed Jun 22, 2017
0 parents commit 159cc5c
Showing 10 changed files with 381 additions and 0 deletions.
47 changes: 47 additions & 0 deletions README.md
@@ -0,0 +1,47 @@
# TensorFlow on Slurm

This package makes it easier to run distributed TensorFlow jobs on Slurm clusters. It contains functions for parsing the Slurm environment variables in order to create a configuration for distributed TF.

## Prerequisites

You need to have TensorFlow installed. All the examples were tested with TensorFlow 1.0.1, but other versions also have a good chance of working correctly.

### Installation

To install, execute the following on the command line:
```
git clone https://github.com/deepsense-io/tensorflow_on_slurm
cd tensorflow_on_slurm
sudo pip install .
```

## Usage

A complete usage example using the CIFAR-10 dataset is included in the examples directory.

However, if you just want to dive in, you can paste the following snippet into your script:

```python
import tensorflow as tf
from tensorflow_on_slurm import tf_config_from_slurm

cluster, my_job_name, my_task_index = tf_config_from_slurm(ps_number=1)
cluster_spec = tf.train.ClusterSpec(cluster)
server = tf.train.Server(server_or_cluster_def=cluster_spec,
                         job_name=my_job_name,
                         task_index=my_task_index)
```
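
From here the parameter server and the workers diverge. A minimal sketch of what typically comes next, following the pattern used in `examples/cifar-10/main.py`:

```python
import sys

if my_job_name == 'ps':
    # A parameter server only hosts variables: it blocks here and
    # serves the workers until the job is terminated.
    server.join()
    sys.exit(0)

# A worker pins its part of the graph to itself and builds the model,
# keeping the variables on the parameter server.
with tf.device('/job:worker/task:{}'.format(my_task_index)):
    pass  # build your model here
```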
## Issues

It's possible that our tests don't cover all the corner cases concerning Slurm node names etc. If you happen to spot a bug, please don't hesitate to file an issue here on GitHub.

## Contributing
Pull requests are more than welcome. If you'd like to add new functionality, don't forget to write unit tests for it and make sure you don't break any currently passing tests (see below for how to run them).

## Tests
To run the tests, issue:

```
python tensorflow_on_slurm/tests/test.py
```

20 changes: 20 additions & 0 deletions examples/cifar-10/README.md
@@ -0,0 +1,20 @@
# CIFAR-10 Example

This directory contains an example job training a simple CNN model on CIFAR-10 using a Slurm cluster and distributed TensorFlow. Please note that the example is meant only to illustrate the usage of this package; the accuracy of the model could certainly be improved.

## Running the example

First, create a virtualenv and install TensorFlow and tensorflow_on_slurm in it:

```
virtualenv cifar-10-job-venv
source cifar-10-job-venv/bin/activate
pip install tensorflow
pip install git+https://github.com/deepsense-io/tensorflow_on_slurm
```

The virtualenv is necessary to make sure the packages are available on all the nodes taking part in the job (its path is passed to job.sh, which activates it on each node). To run the example script on 5 nodes (4 workers + 1 parameter server), issue:

```
srun -N 5 -n 5 -t 24:00:00 job.sh cifar-10-job-venv
```
4 changes: 4 additions & 0 deletions examples/cifar-10/download_data.sh
@@ -0,0 +1,4 @@
#! /bin/bash

wget https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
tar -xvzf cifar-10-python.tar.gz
5 changes: 5 additions & 0 deletions examples/cifar-10/job.sh
@@ -0,0 +1,5 @@
#! /bin/bash
echo "SLURM_JOB_ID " $SLURM_JOB_ID "; SLURM_JOB_NAME " $SLURM_JOB_NAME "; SLURM_JOB_NODELIST " $SLURM_JOB_NODELIST "; SLURMD_NODENAME " $SLURMD_NODENAME "; SLURM_JOB_NUM_NODES " $SLURM_JOB_NUM_NODES
echo $1
source "$1"/bin/activate
python main.py
146 changes: 146 additions & 0 deletions examples/cifar-10/main.py
@@ -0,0 +1,146 @@
from __future__ import print_function
from __future__ import absolute_import
from __future__ import division

import pickle
import os
import tensorflow as tf
import numpy as np
import sys
import time
from tensorflow_on_slurm import tf_config_from_slurm

cluster, my_job_name, my_task_index = tf_config_from_slurm(ps_number=1)
cluster_spec = tf.train.ClusterSpec(cluster)
server = tf.train.Server(server_or_cluster_def=cluster_spec,
                         job_name=my_job_name,
                         task_index=my_task_index)

if my_job_name == 'ps':
    server.join()
    sys.exit(0)

data_dir = 'cifar-10-batches-py'
filelist = [os.path.join(data_dir, 'data_batch_1'),
            os.path.join(data_dir, 'data_batch_2'),
            os.path.join(data_dir, 'data_batch_3'),
            os.path.join(data_dir, 'data_batch_4'),
            os.path.join(data_dir, 'data_batch_5')]

data, labels = [], []
is_chief = my_task_index == 0

# Load all five CIFAR-10 training batches into memory.
for f in filelist:
    with open(f, 'rb') as fo:
        data_elem = pickle.load(fo)
    data.append(data_elem['data'])
    labels.extend(data_elem['labels'])
data = np.vstack(data)
print('data shape: ', data.shape)

def weight_variable(shape):
    # Keep trainable variables on the parameter server.
    with tf.device("/job:ps/task:0"):
        initial = tf.truncated_normal(shape, stddev=0.1)
        v = tf.Variable(initial)
    return v

def bias_variable(shape):
    with tf.device("/job:ps/task:0"):
        initial = tf.constant(0.1, shape=shape)
        v = tf.Variable(initial)
    return v

def conv2d(x, W):
    return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='VALID')

def max_pool_2x2(x):
    return tf.nn.max_pool(x, ksize=[1, 2, 2, 1],
                          strides=[1, 2, 2, 1], padding='SAME')

with tf.device('/job:worker/task:{}'.format(my_task_index)):
    x = tf.placeholder(tf.float32, shape=[None, 3072], name='x')
    y = tf.placeholder(tf.uint8, shape=[None, 1], name='y')
    y_one_hot = tf.one_hot(indices=y, depth=10)

    # FIRST CONVOLUTIONAL LAYER
    ks = 5
    n_filters1 = 16
    W_conv1 = weight_variable([ks, ks, 3, n_filters1])
    b_conv1 = bias_variable([n_filters1])

    # CIFAR-10 rows are flat 3072-element vectors in channel-first order;
    # reshape to NHWC and scale pixel values to roughly [-1, 1].
    reshaped = tf.reshape(x, [-1, 3, 32, 32])
    transposed = tf.transpose(reshaped, [0, 2, 3, 1])
    x_image = (transposed - 128) / 128

    h_conv1 = tf.nn.relu(conv2d(x_image, W_conv1) + b_conv1)
    h_pool1 = max_pool_2x2(h_conv1)

    # SECOND CONVOLUTIONAL LAYER
    n_filters2 = 64
    W_conv2 = weight_variable([ks, ks, n_filters1, n_filters2])
    b_conv2 = bias_variable([n_filters2])

    h_conv2 = tf.nn.relu(conv2d(h_pool1, W_conv2) + b_conv2)
    h_pool2 = max_pool_2x2(h_conv2)

    # FULLY CONNECTED LAYER
    hidden_neurons = 512
    W_fc1 = weight_variable([5 * 5 * n_filters2, hidden_neurons])
    b_fc1 = bias_variable([hidden_neurons])

    h_pool2_flat = tf.reshape(h_pool2, [-1, 5 * 5 * n_filters2])
    h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1) + b_fc1)

    # DROPOUT
    keep_prob = tf.placeholder(tf.float32)
    h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob)

    # SOFTMAX
    W_fc2 = weight_variable([hidden_neurons, 10])
    b_fc2 = bias_variable([10])

    y_conv = tf.matmul(h_fc1_drop, W_fc2) + b_fc2
    cross_entropy = tf.nn.softmax_cross_entropy_with_logits(logits=y_conv, labels=y_one_hot)
    loss = tf.reduce_mean(cross_entropy)
    opt = tf.train.AdamOptimizer(1e-3)
    # Wait for gradients from all workers before applying an update.
    opt = tf.train.SyncReplicasOptimizer(opt, replicas_to_aggregate=len(cluster['worker']),
                                         total_num_replicas=len(cluster['worker']))
    global_step = bias_variable([])
    train_step = opt.minimize(loss, global_step=global_step)
    sync_replicas_hook = opt.make_session_run_hook(is_chief)

    y_hat = tf.round(tf.argmax(tf.nn.softmax(y_conv), 1))
    y_hat = tf.cast(y_hat, tf.uint8)
    y_hat = tf.reshape(y_hat, [-1, 1])
    correct_prediction = tf.equal(y_hat, y)
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

def batch_generator(data, labels, batch_size=32):
    x_batch, y_batch = [], []
    for d, l in zip(data, labels):
        x_batch.append(d)
        y_batch.append(l)
        if len(x_batch) == batch_size:
            yield np.vstack(x_batch), np.vstack(y_batch)
            x_batch = []
            y_batch = []

epochs = 1000
batch_size = 128
step = 0
sess = tf.train.MonitoredTrainingSession(master=server.target, is_chief=is_chief,
                                         hooks=[sync_replicas_hook])

for i in range(epochs):
    bg = batch_generator(data, labels, batch_size)
    for j, (data_batch, label_batch) in enumerate(bg):
        # Simple round-robin sharding: each worker trains only on
        # every len(cluster['worker'])-th batch.
        if (j + i) % len(cluster['worker']) != my_task_index:
            continue
        _, loss_, acc = sess.run([train_step, loss, accuracy],
                                 feed_dict={x: data_batch,
                                            y: label_batch.reshape(-1, 1),
                                            keep_prob: 0.5})
        step += 1
        print(step, my_task_index, loss_, acc)
        sys.stdout.flush()
5 changes: 5 additions & 0 deletions examples/cifar-10/my_job.sh
@@ -0,0 +1,5 @@
#! /bin/bash
echo "SLURM_JOB_ID " $SLURM_JOB_ID "; SLURM_JOB_NAME " $SLURM_JOB_NAME "; SLURM_JOB_NODELIST " $SLURM_JOB_NODELIST "; SLURMD_NODENAME " $SLURMD_NODENAME "; SLURM_JOB_NUM_NODES " $SLURM_JOB_NUM_NODES
source /net/people/plgtgrel/simple_venv/bin/activate
python main.py
echo "DONE"
11 changes: 11 additions & 0 deletions setup.py
@@ -0,0 +1,11 @@
#!/usr/bin/env python
from distutils.core import setup

setup(name='tensorflow_on_slurm',
      version='0.1',
      description='Helps in running distributed TensorFlow on Slurm clusters',
      author='Tomasz Grel',
      author_email='[email protected]',
      url='',
      packages=['tensorflow_on_slurm'],
      )
3 changes: 3 additions & 0 deletions tensorflow_on_slurm/__init__.py
@@ -0,0 +1,3 @@
from .tensorflow_on_slurm import tf_config_from_slurm

__all__ = ['tf_config_from_slurm']
69 changes: 69 additions & 0 deletions tensorflow_on_slurm/tensorflow_on_slurm.py
@@ -0,0 +1,69 @@
from __future__ import print_function
from __future__ import absolute_import
from __future__ import division

import tensorflow as tf
import re
import os

def tf_config_from_slurm(ps_number, port_number=2222):
    """
    Creates a configuration for a distributed TensorFlow session
    from the environment variables provided by the Slurm cluster
    management system.

    @param ps_number: number of parameter servers to run
    @param port_number: port number to be used for communication
    @return: a tuple (cluster, my_job_name, my_task_index) where
        cluster is a dict mapping job names to lists of host:port
        strings (suitable for tf.train.ClusterSpec), my_job_name is
        this node's job name ('ps' or 'worker') and my_task_index is
        this node's task index within that job
    """
    nodelist = os.environ["SLURM_JOB_NODELIST"]
    nodename = os.environ["SLURMD_NODENAME"]
    nodelist = _expand_nodelist(nodelist)
    num_nodes = int(os.getenv("SLURM_JOB_NUM_NODES"))

    if len(nodelist) != num_nodes:
        raise ValueError("Number of Slurm nodes {} not equal to {}".format(len(nodelist), num_nodes))

    if nodename not in nodelist:
        raise ValueError("Nodename ({}) not in nodelist ({}). This should not happen!".format(nodename, nodelist))

    # The first ps_number nodes become parameter servers, the rest workers.
    ps_nodes = nodelist[:ps_number]
    worker_nodes = nodelist[ps_number:]

    if nodename in ps_nodes:
        my_job_name = "ps"
        my_task_index = ps_nodes.index(nodename)
    else:
        my_job_name = "worker"
        my_task_index = worker_nodes.index(nodename)

    worker_sockets = [":".join([node, str(port_number)]) for node in worker_nodes]
    ps_sockets = [":".join([node, str(port_number)]) for node in ps_nodes]
    cluster = {"worker": worker_sockets, "ps": ps_sockets}

    return cluster, my_job_name, my_task_index
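
# Illustrative example (not executed): with SLURM_JOB_NODELIST="p[1135-1139]",
# ps_number=1 and the default port, tf_config_from_slurm returns
#   {"ps": ["p1135:2222"],
#    "worker": ["p1136:2222", "p1137:2222", "p1138:2222", "p1139:2222"]}
# together with this node's job name and task index.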

def _pad_zeros(iterable, length):
    return (str(t).rjust(length, '0') for t in iterable)

def _expand_ids(ids):
    ids = ids.split(',')
    result = []
    for id in ids:
        if '-' in id:
            # A range like '0900-0910': expand it and left-pad each id
            # with zeros to the width of the range's end token.
            begin_str, end_str = id.split('-')
            begin, end = int(begin_str), int(end_str)
            result.extend(_pad_zeros(range(begin, end + 1), len(end_str)))
        else:
            result.append(id)
    return result

def _expand_nodelist(nodelist):
    prefix, ids = re.findall(r"(.*)\[(.*)\]", nodelist)[0]
    ids = _expand_ids(ids)
    result = [prefix + str(id) for id in ids]
    return result
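
# Illustrative example (not executed):
# _expand_nodelist('p[0900-0902,0910]') -> ['p0900', 'p0901', 'p0902', 'p0910']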

def _worker_task_id(nodelist, nodename):
    return nodelist.index(nodename)
71 changes: 71 additions & 0 deletions tensorflow_on_slurm/tests/test.py
@@ -0,0 +1,71 @@
from __future__ import print_function
from __future__ import absolute_import
from __future__ import division

import sys
import os
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))

import unittest

from tensorflow_on_slurm import tf_config_from_slurm, _expand_ids, _expand_nodelist, _worker_task_id

class BasicTestData(unittest.TestCase):
    def setUp(self):
        unittest.TestCase.setUp(self)
        self.nodelist = 'p[1135,1137-1142,1147-1148,1152]'
        self.first_nodename = 'p1135'
        self.nodename = 'p1140'
        self.nodes_number = 10

class ShortNodenameTestData(unittest.TestCase):
    def setUp(self):
        unittest.TestCase.setUp(self)
        self.nodelist = 'p[0900-0910]'
        self.first_nodename = 'p0900'
        self.nodename = 'p0902'
        self.nodes_number = 11

class ShortNodenameTestData2(unittest.TestCase):
    def setUp(self):
        unittest.TestCase.setUp(self)
        self.nodelist = 'p[0900,0910]'
        self.first_nodename = 'p0900'
        self.nodename = 'p0910'
        self.nodes_number = 2

class TensorflowSlurmUtilsTest(object):
    def test_expand_ids(self):
        test_ids = '1-5,7,8-12'
        res = _expand_ids(test_ids)
        # 5 ids from '1-5', one from '7' and 5 from '8-12'.
        self.assertEqual(len(res), 11)

    def test_expand_nodelist(self):
        expanded = _expand_nodelist(self.nodelist)
        self.assertEqual(len(expanded), self.nodes_number)
        self.assertIn(self.nodename, expanded)

    def test_first_task_id(self):
        expanded = _expand_nodelist(self.nodelist)
        first_task_id = _worker_task_id(expanded, self.first_nodename)
        self.assertEqual(first_task_id, 0)

    def test_other_task_id(self):
        expanded = _expand_nodelist(self.nodelist)
        task_id = _worker_task_id(expanded, self.nodename)
        self.assertIn(task_id, range(self.nodes_number))

    def test_tf_config_from_slurm(self):
        os.environ["SLURM_JOB_NODELIST"] = self.nodelist
        os.environ["SLURMD_NODENAME"] = self.nodename
        os.environ["SLURM_JOB_NUM_NODES"] = str(self.nodes_number)
        cluster, my_job_name, my_task_index = tf_config_from_slurm(ps_number=2)
        self.assertEqual(len(cluster['ps']) + len(cluster['worker']), self.nodes_number)
        self.assertIn(my_job_name, ('ps', 'worker'))
        self.assertGreaterEqual(my_task_index, 0)

class BasicTestCase(BasicTestData, TensorflowSlurmUtilsTest):
    pass

class ShortNodenameTestCase(ShortNodenameTestData, TensorflowSlurmUtilsTest):
    pass

class ShortNodenameTestCase2(ShortNodenameTestData2, TensorflowSlurmUtilsTest):
    pass

if __name__ == '__main__':
    unittest.main()
