From 159cc5c5260a10c9c46af3b91a7901505e77b106 Mon Sep 17 00:00:00 2001 From: Tomasz Grel Date: Thu, 22 Jun 2017 08:44:57 +0200 Subject: [PATCH] Initial commit --- README.md | 47 +++++++ examples/cifar-10/README.md | 20 +++ examples/cifar-10/download_data.sh | 4 + examples/cifar-10/job.sh | 5 + examples/cifar-10/main.py | 146 +++++++++++++++++++++ examples/cifar-10/my_job.sh | 5 + setup.py | 11 ++ tensorflow_on_slurm/__init__.py | 3 + tensorflow_on_slurm/tensorflow_on_slurm.py | 69 ++++++++++ tensorflow_on_slurm/tests/test.py | 71 ++++++++++ 10 files changed, 381 insertions(+) create mode 100644 README.md create mode 100644 examples/cifar-10/README.md create mode 100644 examples/cifar-10/download_data.sh create mode 100644 examples/cifar-10/job.sh create mode 100644 examples/cifar-10/main.py create mode 100644 examples/cifar-10/my_job.sh create mode 100644 setup.py create mode 100644 tensorflow_on_slurm/__init__.py create mode 100644 tensorflow_on_slurm/tensorflow_on_slurm.py create mode 100644 tensorflow_on_slurm/tests/test.py diff --git a/README.md b/README.md new file mode 100644 index 0000000..04f4942 --- /dev/null +++ b/README.md @@ -0,0 +1,47 @@ +# Tensorflow on Slurm + +This package makes it easier to run distributed TensorFlow jobs on slurm clusters. It contains functions for parsing the Slurm environment variables in order to create configuration for distributed TF. + +## Prerequisites + +You need to have TensorFlow installed. All the examples were tested with TensorFlow 1.0.1, but other versions also have a good chance of working correctly. + +### Installation + +To install execute the following on the command line: +``` +git clone +cd tensorflow_on_slurm +sudo pip install . +``` + +## Usage + +A complete usage example using the CIFAR-10 dataset is included in the examples directory. + +However if you just want to dive in you can paste the following snippet into your script: + +```python +import tensorflow as tf +from tensorflow_on_slurm import tf_config_from_slurm + +cluster, my_job_name, my_task_index = tf_config_from_slurm(ps_number=1) +cluster_spec = tf.train.ClusterSpec(cluster) +server = tf.train.Server(server_or_cluster_def=cluster_spec, + job_name=my_job_name, + task_index=my_task_index) +``` +## Issues + +It's possible that our tests don't cover all the corner cases about the names of the Slurm nodes etc. If you happen to spot some bugs please don't hesitate to file an issue here on github. + +## Contributing +Pull request are more than welcome. If you'd like to add some new functionality don't forget to write unit tests for it and make sure you don't break any currently working tests (see below on how to run the tests) + +## Tests +To run the tests issue: + +``` +python tensorflow_on_slurm/tests/test.py +``` + diff --git a/examples/cifar-10/README.md b/examples/cifar-10/README.md new file mode 100644 index 0000000..479e600 --- /dev/null +++ b/examples/cifar-10/README.md @@ -0,0 +1,20 @@ +# CIFAR-10 Example + +This directory contains an example job training a simple CNN model for CIFAR-10 on a Slurm cluster with distributed TensorFlow. Please note that the example is meant only to illustrate the usage of this package, the accuracy of the model could certainly be easily improved. + +## Running the example + +Firstly, create a virtualenv and install TensorFlow and tensorflow_on_slurm in it with: + +``` +virtualenv cifar-10-job-venv +source venv/bin/activate +pip install tensorflow +pip install git+https://github.com/deepsense-io/tensorflow_on_slurm +``` + +The virtualenv is necessary to make sure the packages are there on all the nodes taking part in the job. Now to run the example script on 5 nodes (4 workers + 1 parameter server) issue: + +``` +srun -N 5 -n 5 -t 24:00:00 job.sh cifar-10-job-venv +``` diff --git a/examples/cifar-10/download_data.sh b/examples/cifar-10/download_data.sh new file mode 100644 index 0000000..f7e95d1 --- /dev/null +++ b/examples/cifar-10/download_data.sh @@ -0,0 +1,4 @@ +#! /bin/bash + +wget https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz . +tar -xvzf cifar-10-python.tar.gz diff --git a/examples/cifar-10/job.sh b/examples/cifar-10/job.sh new file mode 100644 index 0000000..7d1b359 --- /dev/null +++ b/examples/cifar-10/job.sh @@ -0,0 +1,5 @@ +#! /bin/bash +echo "SLURM_JOB_ID " $SLURM_JOB_ID "; SLURM_JOB_NAME " $SLURM_JOB_NAME "; SLURM_JOB_NODELIST " $SLURM_JOB_NODELIST "; SLURMD_NODENAME " $SLURMD_NODENAME "; SLURM_JOB_NUM_NODES " $SLURM_JOB_NUM_NODES +echo $1 +source "$1"/bin/activate +python main.py diff --git a/examples/cifar-10/main.py b/examples/cifar-10/main.py new file mode 100644 index 0000000..24551a4 --- /dev/null +++ b/examples/cifar-10/main.py @@ -0,0 +1,146 @@ +from __future__ import print_function +from __future__ import absolute_import +from __future__ import division + +import pickle +import os +import tensorflow as tf +import numpy as np +import sys +import time +from tensorflow_on_slurm import tf_config_from_slurm + +cluster, my_job_name, my_task_index = tf_config_from_slurm(ps_number=1) +cluster_spec = tf.train.ClusterSpec(cluster) +server = tf.train.Server(server_or_cluster_def=cluster_spec, + job_name=my_job_name, + task_index=my_task_index) + +if my_job_name == 'ps': + server.join() + sys.exit(0) + +data_dir = 'cifar-10-batches-py' +filelist = [os.path.join(data_dir, 'data_batch_1'), + os.path.join(data_dir, 'data_batch_2'), + os.path.join(data_dir, 'data_batch_3'), + os.path.join(data_dir, 'data_batch_4'), + os.path.join(data_dir, 'data_batch_5')] + +data, labels = [], [] +is_chief = my_task_index == 0 + +for f in filelist: + with open(f, 'rb') as fo: + data_elem = pickle.load(fo) + data.append(data_elem['data']) + labels.extend(data_elem['labels']) +data = np.vstack(d for d in data) +print('data shape: ', data.shape) + +def weight_variable(shape): + with tf.device("/job:ps/task:0"): + initial = tf.truncated_normal(shape, stddev=0.1) + v = tf.Variable(initial) + return v + +def bias_variable(shape): + with tf.device("/job:ps/task:0"): + initial = tf.constant(0.1, shape=shape) + v = tf.Variable(initial) + return v + +def conv2d(x, W): + return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='VALID') + +def max_pool_2x2(x): + return tf.nn.max_pool(x, ksize=[1, 2, 2, 1], + strides=[1, 2, 2, 1], padding='SAME') + +with tf.device('/job:worker/task:{}'.format(my_task_index)): + x = tf.placeholder(tf.float32, shape=[None, 3072], name='x') + y = tf.placeholder(tf.uint8, shape=[None, 1], name='y') + + # FIRST CONVOLUTIONAL LAYER + y_one_hot = tf.one_hot(indices=y, depth=10) + + ks = 5 + n_filters1 = 16 + W_conv1 = weight_variable([ks, ks, 3, n_filters1]) + b_conv1 = bias_variable([n_filters1]) + + reshaped = tf.reshape(x, [-1, 3, 32, 32]) + transposed = tf.transpose(reshaped, [0, 2, 3, 1]) + x_image = (transposed - 128) / 128 + + h_conv1 = tf.nn.relu(conv2d(x_image, W_conv1) + b_conv1) + h_pool1 = max_pool_2x2(h_conv1) + + # SECOND CONVOLUTIONAL LAYER + n_filters2 = 64 + W_conv2 = weight_variable([ks, ks, n_filters1, n_filters2]) + b_conv2 = bias_variable([n_filters2]) + + h_conv2 = tf.nn.relu(conv2d(h_pool1, W_conv2) + b_conv2) + h_pool2 = max_pool_2x2(h_conv2) + + # FULLY CONNECTED LAYER + hidden_neurons = 512 + W_fc1 = weight_variable([5 * 5 * n_filters2, hidden_neurons]) + b_fc1 = bias_variable([hidden_neurons]) + + h_pool2_flat = tf.reshape(h_pool2, [-1, 5 * 5 * n_filters2]) + h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1) + b_fc1) + + # DROPOUT + keep_prob = tf.placeholder(tf.float32) + h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob) + + # SOFTMAX + W_fc2 = weight_variable([hidden_neurons, 10]) + b_fc2 = bias_variable([10]) + + y_conv = tf.matmul(h_fc1_drop, W_fc2) + b_fc2 + cross_entropy = tf.nn.softmax_cross_entropy_with_logits(logits=y_conv, labels=y_one_hot) + loss = tf.reduce_mean(cross_entropy) + opt = tf.train.AdamOptimizer(1e-3) + opt = tf.train.SyncReplicasOptimizer(opt, replicas_to_aggregate=len(cluster['worker']), + total_num_replicas=len(cluster['worker'])) + global_step = bias_variable([]) + train_step = opt.minimize(loss, global_step=global_step) + sync_replicas_hook = opt.make_session_run_hook(is_chief) + + y_hat = tf.round(tf.argmax(tf.nn.softmax(y_conv), 1)) + y_hat = tf.cast(y_hat, tf.uint8) + y_hat = tf.reshape(y_hat, [-1, 1]) + correct_prediction = tf.equal(y_hat, y) + accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) + +def batch_generator(data, labels, batch_size=32): + x_batch, y_batch = [], [] + for d, l in zip(data, labels): + x_batch.append(d) + y_batch.append(l) + if len(x_batch) == batch_size: + yield np.vstack(x_batch),np.vstack(y_batch) + x_batch = [] + y_batch = [] + +epochs = 1000 +batch_size = 128 +step = 0 +sess = tf.train.MonitoredTrainingSession(master=server.target, is_chief=is_chief, + hooks=[sync_replicas_hook]) + +for i in range(epochs): + bg = batch_generator(data, labels, batch_size) + for j, (data_batch, label_batch) in enumerate(bg): + if (j+i) % len(cluster['worker']) != my_task_index: + continue + _, loss_, acc = sess.run([train_step, loss, accuracy], + feed_dict={x: data_batch, + y: label_batch.reshape(-1,1), + keep_prob: 0.5}) + step += 1 + print(step, my_task_index, loss_, acc) + sys.stdout.flush() diff --git a/examples/cifar-10/my_job.sh b/examples/cifar-10/my_job.sh new file mode 100644 index 0000000..c9f06ff --- /dev/null +++ b/examples/cifar-10/my_job.sh @@ -0,0 +1,5 @@ +#! /bin/bash +echo "SLURM_JOB_ID " $SLURM_JOB_ID "; SLURM_JOB_NAME " $SLURM_JOB_NAME "; SLURM_JOB_NODELIST " $SLURM_JOB_NODELIST "; SLURMD_NODENAME " $SLURMD_NODENAME "; SLURM_JOB_NUM_NODES " $SLURM_JOB_NUM_NODES +source /net/people/plgtgrel/simple_venv/bin/activate +python main.py +echo "DONE" diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..94e4f10 --- /dev/null +++ b/setup.py @@ -0,0 +1,11 @@ +#!/usr/bin/env python +from distutils.core import setup + +setup(name='tensorflow_on_slurm', + version='0.1', + description='Helps in running distributed TensorFlow on Slurm clusters', + author='Tomasz Grel', + author_email='tomasz.grel@deepsense.io', + url='', + packages=['tensorflow_on_slurm'], + ) diff --git a/tensorflow_on_slurm/__init__.py b/tensorflow_on_slurm/__init__.py new file mode 100644 index 0000000..7d415d8 --- /dev/null +++ b/tensorflow_on_slurm/__init__.py @@ -0,0 +1,3 @@ +from .tensorflow_on_slurm import tf_config_from_slurm + +__all__ = ['tf_config_from_slurm'] diff --git a/tensorflow_on_slurm/tensorflow_on_slurm.py b/tensorflow_on_slurm/tensorflow_on_slurm.py new file mode 100644 index 0000000..9393b4b --- /dev/null +++ b/tensorflow_on_slurm/tensorflow_on_slurm.py @@ -0,0 +1,69 @@ +from __future__ import print_function +from __future__ import absolute_import +from __future__ import division + +import tensorflow as tf +import re +import os + +def tf_config_from_slurm(ps_number, port_number=2222): + """ + Creates configuration for a distributed tensorflow session + from environment variables provided by the Slurm cluster + management system. + + @param: ps_number number of parameter servers to run + @param: port_number port number to be used for communication + @return: a tuple containing cluster with fields cluster_spec, + task_name and task_id + """ + + nodelist = os.environ["SLURM_JOB_NODELIST"] + nodename = os.environ["SLURMD_NODENAME"] + nodelist = _expand_nodelist(nodelist) + num_nodes = int(os.getenv("SLURM_JOB_NUM_NODES")) + + if len(nodelist) != num_nodes: + raise ValueError("Number of slurm nodes {} not equal to {}".format(len(nodelist), num_nodes)) + + if nodename not in nodelist: + raise ValueError("Nodename({}) not in nodelist({}). This should not happen! ".format(nodename,nodelist)) + + ps_nodes = [node for i, node in enumerate(nodelist) if i < ps_number] + worker_nodes = [node for i, node in enumerate(nodelist) if i >= ps_number] + + if nodename in ps_nodes: + my_job_name = "ps" + my_task_index = ps_nodes.index(nodename) + else: + my_job_name = "worker" + my_task_index = worker_nodes.index(nodename) + + worker_sockets = [":".join([node, str(port_number)]) for node in worker_nodes] + ps_sockets = [":".join([node, str(port_number)]) for node in ps_nodes] + cluster = {"worker": worker_sockets, "ps" : ps_sockets} + + return cluster, my_job_name, my_task_index + +def _pad_zeros(iterable, length): + return (str(t).rjust(length, '0') for t in iterable) + +def _expand_ids(ids): + ids = ids.split(',') + result = [] + for id in ids: + if '-' in id: + begin, end = [int(token) for token in id.split('-')] + result.extend(_pad_zeros(range(begin, end+1), len(token))) + else: + result.append(id) + return result + +def _expand_nodelist(nodelist): + prefix, ids = re.findall("(.*)\[(.*)\]", nodelist)[0] + ids = _expand_ids(ids) + result = [prefix + str(id) for id in ids] + return result + +def _worker_task_id(nodelist, nodename): + return nodelist.index(nodename) diff --git a/tensorflow_on_slurm/tests/test.py b/tensorflow_on_slurm/tests/test.py new file mode 100644 index 0000000..7d21976 --- /dev/null +++ b/tensorflow_on_slurm/tests/test.py @@ -0,0 +1,71 @@ +from __future__ import print_function +from __future__ import absolute_import +from __future__ import division + +import sys +import os +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..')) + +import unittest + +from tensorflow_on_slurm import tf_config_from_slurm, _expand_ids, _expand_nodelist, _worker_task_id + +class BasicTestData(unittest.TestCase): + def setUp(self): + unittest.TestCase.setUp(self) + self.nodelist = 'p[1135,1137-1142,1147-1148,1152]' + self.first_nodename = 'p1135' + self.nodename = 'p1140' + self.nodes_number = 10 + +class ShortNodenameTestData(unittest.TestCase): + def setUp(self): + unittest.TestCase.setUp(self) + self.nodelist = 'p[0900-0910]' + self.first_nodename = 'p0900' + self.nodename = 'p0902' + self.nodes_number = 11 + +class ShortNodenameTestData2(unittest.TestCase): + def setUp(self): + unittest.TestCase.setUp(self) + self.nodelist = 'p[0900,0910]' + self.first_nodename = 'p0900' + self.nodename = 'p0910' + self.nodes_number = 2 + +class TensorflowSlurmUtilsTest(object): + def test_expand_ids(self): + test_ids = '1-5,7,8-12' + res = _expand_ids(test_ids) + + def test_expand_nodelist(self): + expanded = _expand_nodelist(self.nodelist) + self.assertEqual(len(expanded), self.nodes_number) + self.assertIn(self.nodename, expanded) + + def test_first_task_id(self): + expanded = _expand_nodelist(self.nodelist) + first_task_id = _worker_task_id(expanded, self.first_nodename) + self.assertEqual(first_task_id, 0) + + def test_other_task_id(self): + expanded = _expand_nodelist(self.nodelist) + task_id = _worker_task_id(expanded, self.nodename) + self.assertIn(task_id, range(self.nodes_number)) + + def test_tf_config_from_slurm(self): + os.environ["SLURM_JOB_NODELIST"] = self.nodelist + os.environ["SLURMD_NODENAME"] = self.nodename + os.environ["SLURM_JOB_NUM_NODES"] = str(self.nodes_number) + cluster, my_job_name, my_task_index = tf_config_from_slurm(ps_number=2) + +class BasicTestCase(BasicTestData, TensorflowSlurmUtilsTest): + pass +class ShortNodenameTestCase(ShortNodenameTestData, TensorflowSlurmUtilsTest): + pass +class ShortNodenameTestCase2(ShortNodenameTestData2, TensorflowSlurmUtilsTest): + pass + +if __name__ == '__main__': + unittest.main()