Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Issue with list literal inner variable #5

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 14 additions & 13 deletions tensorflow_on_slurm/tensorflow_on_slurm.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,53 +8,54 @@

def tf_config_from_slurm(ps_number, port_number=2222):
    """
    Builds the cluster description for a distributed tensorflow session
    out of the environment variables set by the Slurm cluster
    management system.

    The first ps_number entries of the expanded node list become
    parameter servers ("ps"); every remaining entry becomes a worker.

    @param: ps_number number of parameter servers to run
    @param: port_number port number to be used for communication
    @return: a tuple (cluster, job_name, task_index) where cluster maps
             "worker" and "ps" to lists of "node:port" strings, job_name
             is "ps" or "worker" for the current node, and task_index is
             the position of the current node inside its own group
    @raise: ValueError if the expanded node list does not have
            SLURM_JOB_NUM_NODES entries or does not contain
            SLURMD_NODENAME
    """
    raw_nodelist = os.environ["SLURM_JOB_NODELIST"]
    nodename = os.environ["SLURMD_NODENAME"]
    nodelist = _expand_nodelist(raw_nodelist)
    num_nodes = int(os.getenv("SLURM_JOB_NUM_NODES"))

    node_count = len(nodelist)
    if node_count != num_nodes:
        raise ValueError("Number of slurm nodes {} not equal to {}".format(node_count, num_nodes))

    if nodename not in nodelist:
        raise ValueError("Nodename({}) not in nodelist({}). This should not happen! ".format(nodename,nodelist))

    # Single pass over the node list: leading entries are parameter
    # servers, the rest are workers.
    ps_nodes = []
    worker_nodes = []
    for position, node in enumerate(nodelist):
        if position < ps_number:
            ps_nodes.append(node)
        else:
            worker_nodes.append(node)

    if nodename in ps_nodes:
        my_job_name, my_group = "ps", ps_nodes
    else:
        my_job_name, my_group = "worker", worker_nodes
    my_task_index = my_group.index(nodename)

    port_suffix = ":" + str(port_number)
    cluster = {
        "worker": [node + port_suffix for node in worker_nodes],
        "ps": [node + port_suffix for node in ps_nodes],
    }

    return cluster, my_job_name, my_task_index

def _pad_zeros(iterable, length):
return (str(t).rjust(length, '0') for t in iterable)

def _expand_ids(ids):
ids = ids.split(',')
result = []
for id in ids:
if '-' in id:
begin, end = [int(token) for token in id.split('-')]
result.extend(_pad_zeros(range(begin, end+1), len(token)))
split = id.split('-')
begin, end = [int(token) for token in split]
result.extend(_pad_zeros(range(begin, end+1), len(split[-1])))
else:
result.append(id)
return result
Expand Down