[Runner] Add query mode #113

Closed
7 changes: 7 additions & 0 deletions flagscale/launcher/job_status.py
@@ -0,0 +1,7 @@
from enum import Enum


class JobStatus(Enum):
RUNNING = "Running"
TRANSITIONAL = "Transitional (Stopping or Starting)"
COMPLETED_OR_IDLE = "Completed or Not Started"
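
For context, the query mode added below reduces the per-host query results to one of these three values. A minimal sketch of that aggregation, mirroring the _query_status logic introduced later in this PR (the results list is illustrative):

from flagscale.launcher.job_status import JobStatus

def aggregate(results):
    # results holds one `ps` state string per host; "" means no matching process
    if all(status != "" for status in results):
        return JobStatus.RUNNING  # every host reports a live process
    if all(status == "" for status in results):
        return JobStatus.COMPLETED_OR_IDLE  # no host reports a process
    return JobStatus.TRANSITIONAL  # mixed results: starting or stopping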
162 changes: 149 additions & 13 deletions flagscale/launcher/runner.py
@@ -8,10 +8,12 @@
import subprocess
import json
import uuid
import time
from datetime import datetime
from abc import ABC, abstractmethod
from omegaconf import DictConfig, OmegaConf
from ..logger import logger
from .job_status import JobStatus


def log_and_raise_error(message):
@@ -70,7 +72,8 @@ def run_local_command(cmd, dryrun=False):
logger.info(f"SSHRunner is running the local command: {cmd}")
if dryrun:
return
subprocess.run(cmd, shell=True, check=True)
result = subprocess.run(cmd, shell=True, check=True, capture_output=True, text=True)
return result


def run_ssh_command(host, cmd, port=None, dryrun=False):
@@ -81,7 +84,8 @@ def run_ssh_command(host, cmd, port=None, dryrun=False):
logger.info(f"SSHRunner is running the ssh command: {ssh_cmd}")
if dryrun:
return
subprocess.run(ssh_cmd, shell=True, check=True)
result = subprocess.run(ssh_cmd, shell=True, check=True, capture_output=True, text=True)
return result


def run_scp_command(host, src, dst, port=None, dryrun=False):
@@ -150,7 +154,7 @@ def _update_config(config: DictConfig):

if config.get("checkpoint", None) is None:
config.checkpoint = DictConfig({})

if config.get("logging", None) is None:
config.logging = DictConfig({})

@@ -246,7 +250,7 @@ def _get_runner_cmd(
runner_args["nnodes"] = nnodes
runner_args["node_rank"] = node_rank
runner_args["nproc_per_node"] = nproc_per_node
runner_args["rdzv_backend"] = rdzv_backend
runner_args["rdzv_backend"] = rdzv_backend
runner_args["rdzv_endpoint"] = rdzv_endpoint
runner_args["log_dir"] = (
log_dir if backend == "torchrun" else os.path.join(log_dir, rdzv_id)
@@ -269,14 +273,14 @@ def _get_nnodes(nnodes_from_hostfile=None, nnodes_from_args=None):
assert nnodes_from_hostfile is not None or nnodes_from_args is not None
if nnodes_from_hostfile is not None and nnodes_from_args is not None:
if isinstance(nnodes_from_args, str) and ":" in nnodes_from_args:
# Ignore the max nnodes from the args, no elastic support
nnodes_from_args, _ = nnodes_from_args.split(":")
return min(nnodes_from_hostfile, int(nnodes_from_args))
elif nnodes_from_hostfile is not None:
return nnodes_from_hostfile
elif nnodes_from_args is not None:
if isinstance(nnodes_from_args, str) and ":" in nnodes_from_args:
# Ignore the max nnodes from the args, no elastic support
nnodes_from_args, _ = nnodes_from_args.split(":")
return int(nnodes_from_args)

@@ -323,7 +327,9 @@ class SSHRunner(MultiNodeRunner):
def __init__(self, config: DictConfig):
self.config = config
_update_config(self.config)
self.resources = parse_hostfile(self.config.experiment.runner.get("hostfile", None))
self.resources = parse_hostfile(
self.config.experiment.runner.get("hostfile", None)
)

def _prepare(self):
self.rdzv_id = datetime.now().strftime("%Y%m%d_%H%M%S.%f")
@@ -332,7 +338,9 @@ def _prepare(self):
if self.config.experiment.task.type == "train":
self.user_args = get_megatron_args(self.config)
else:
raise ValueError(f"Unsupported task type: {self.config.experiment.task.type}")
raise ValueError(
f"Unsupported task type: {self.config.experiment.task.type}"
)

def _generate_run_script(self, host, node_rank, cmd, with_test=False):
system_config = self.config.train.system
@@ -360,7 +368,7 @@ def _generate_run_script(self, host, node_rank, cmd, with_test=False):
megatron_dir = os.path.join(root_dir, "megatron")
with open(host_run_script_file, "w") as f:
f.write("#!/bin/bash\n\n")
f.write('ulimit -n 1048576\n')
f.write("ulimit -n 1048576\n")
f.write(f"mkdir -p {system_config.checkpoint.load}\n")
f.write(f"mkdir -p {system_config.checkpoint.save}\n")
f.write(f"mkdir -p {system_config.logging.log_dir}\n")
@@ -423,19 +431,27 @@ def _run_each(
test_cmd = f";python tests/functional_tests/check_result.py {exp_dir};rm -r {exp_dir}"
cmd = cmd + test_cmd

host_run_script_file = self._generate_run_script(host, node_rank, cmd, with_test)
host_run_script_file = self._generate_run_script(
host, node_rank, cmd, with_test
)

logging_config = self.config.train.system.logging
if host != "localhost":
ssh_port = self.config.experiment.runner.get("ssh_port", 22)
# Step 1: make sure the scripts_dir exists on the remote host
run_ssh_command(host, f"mkdir -p {logging_config.scripts_dir}", ssh_port, dryrun)
run_ssh_command(
host, f"mkdir -p {logging_config.scripts_dir}", ssh_port, dryrun
)

# Step 2: copy the host_run_script_file to the remote host
no_shared_fs = self.config.experiment.runner.get("no_shared_fs", False)
if no_shared_fs:
run_scp_command(
host, host_run_script_file, logging_config.scripts_dir, ssh_port, dryrun
host,
host_run_script_file,
logging_config.scripts_dir,
ssh_port,
dryrun,
)

# Step 3: run the host_run_script_file on the remote host
@@ -486,7 +502,9 @@ def run(self, with_test=False, dryrun=False):
else:
# If hostfile is not provided, run the job on localhost
nproc_from_args = runner_config.get("nproc_per_node", None)
nproc_per_node = _get_nproc_per_node(None, nproc_from_args, num_visible_devices)
nproc_per_node = _get_nproc_per_node(
None, nproc_from_args, num_visible_devices
)
avaliable_addr = runner_config.get("master_addr", "localhost")
avaliable_port = runner_config.get("master_port", get_free_port())
self._run_each(
@@ -560,3 +578,121 @@ def stop(self):
if node_rank >= nnodes:
break
self._stop_each(host, node_rank)

def _generate_query_script(self, host, node_rank):
"""Genetrate the query script for each host."""
logging_config = self.config.train.system.logging

host_query_script_file = os.path.join(
logging_config.scripts_dir, f"host_{node_rank}_{host}_query.sh"
)

# Check if the host_query_script_file exists
if os.path.exists(host_query_script_file):
return host_query_script_file

host_pid_file = os.path.join(
logging_config.pids_dir, f"host_{node_rank}_{host}.pid"
)

os.makedirs(logging_config.scripts_dir, exist_ok=True)

with open(host_query_script_file, "w") as f:
f.write("#!/bin/bash\n\n")
f.write("if [ -f " + host_pid_file + " ]; then\n")
f.write(" pid=$(cat " + host_pid_file + ")\n")
f.write(" ps -p $pid -o state --no-headers\n")
f.write("else\n")
# TODO: This is a temporary fix. We need to find a better way to query the job.
f.write(
" pid=$(ps aux | grep 'torchrun' | grep -v grep | head -n 1 | awk '{print $2}')\n"
)
f.write(" ps -p $pid -o state --no-headers\n")
f.write("fi\n")
f.flush()
os.fsync(f.fileno())
os.chmod(host_query_script_file, 0o755)

return host_query_script_file
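# Illustration only (not part of this diff): assuming pids_dir is /tmp/pids and
# the target is host 10.0.0.1 at node_rank 0, the generated query script would
# read roughly:
#
#   #!/bin/bash
#   if [ -f /tmp/pids/host_0_10.0.0.1.pid ]; then
#       pid=$(cat /tmp/pids/host_0_10.0.0.1.pid)
#       ps -p $pid -o state --no-headers
#   else
#       pid=$(ps aux | grep 'torchrun' | grep -v grep | head -n 1 | awk '{print $2}')
#       ps -p $pid -o state --no-headers
#   fi
#
# ps prints a single state code (e.g. R, S, Z) while the pid is alive and
# nothing once it is gone, which is what _query_each relies on below.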

def _query_each(self, host, node_rank):
"Query each node status."
host_query_script_file = self._generate_query_script(host, node_rank)
logging_config = self.config.train.system.logging
result = ""
if host != "localhost":
ssh_port = self.config.experiment.runner.get("ssh_port", 22)
# Step 1: make sure the scripts_dir exists on the remote host
run_ssh_command(host, f"mkdir -p {logging_config.scripts_dir}", ssh_port)
# Step 2: copy the host_run_script_file to the remote host
no_shared_fs = self.config.experiment.runner.get("no_shared_fs", False)
if no_shared_fs:
run_scp_command(
host, host_query_script_file, logging_config.scripts_dir, ssh_port
)
# Step 3: run the host_run_script_file on the remote host
try:
result = run_ssh_command(
host, f"bash {host_query_script_file}", ssh_port
)
except Exception as e:
logger.error(f"Failed to query job status on {host}: {e}")
else:
try:
result = run_local_command(f"bash {host_query_script_file}")
except Exception as e:
logger.error(f"Failed to query job status on {host}: {e}")
result = result.stdout.rstrip() if result else ""
return result

def _query_status(self):
"Query Job status."
results = []
if self.resources is None:
result = self._query_each("localhost", 0)
results.append(result)

else:
host_list = list(self.resources.keys())
for host, _ in self.resources.items():
node_rank = host_list.index(host)
result = self._query_each(host, node_rank)
results.append(result)

if all(status != "" for status in results):
job_status = JobStatus.RUNNING
elif all(status == "" for status in results):
job_status = JobStatus.COMPLETED_OR_IDLE
else:
job_status = JobStatus.TRANSITIONAL
return job_status

def query(self, interval=10, timeout=None):
"""
Query the job status and log it.
A job can be in one of three statuses:
RUNNING: The job is running.
COMPLETED_OR_IDLE: The job has completed or has not started.
TRANSITIONAL: The job is starting or stopping.

Args:
interval (int, optional): The interval, in seconds, between status queries. Default: 10.
timeout (float, optional): The timeout, in seconds, for the query; if None, the query continues indefinitely. Default: None.

Returns:
None

"""
if timeout is None:
while True:
job_status = self._query_status()
logger.info(f"Job status: {job_status.name}")
time.sleep(interval)
else:
start_time = time.time()
cur_time = time.time()
while cur_time - start_time < timeout:
job_status = self._query_status()
logger.info(f"Job status: {job_status.name}")
time.sleep(interval)
cur_time = time.time()
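
The polling loop above logs the job status every interval seconds, stopping once timeout (if given) elapses. A minimal usage sketch, with illustrative values:

runner = SSHRunner(config)
runner.query(interval=5, timeout=60)  # log the status every 5 s for up to 60 s
runner.query()  # or poll every 10 s until interrupted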
8 changes: 5 additions & 3 deletions run.py
@@ -1,11 +1,11 @@
import hydra
from omegaconf import DictConfig
from flagscale.logger import logger
from flagscale.launcher.runner import SSHRunner


@hydra.main(version_base=None, config_name="config")
def main(config : DictConfig) -> None:
def main(config: DictConfig) -> None:
runner = SSHRunner(config)

if config.action == "run":
@@ -16,9 +16,11 @@ def main(config : DictConfig) -> None:
runner.run(with_test=True)
elif config.action == "stop":
runner.stop()
elif config.action == "query":
runner.query()
else:
raise ValueError(f"Unknown action {config.action}")


if __name__ == "__main__":
main()
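
With the new action wired into run.py, query mode is invoked like the other actions; assuming the hydra config already defines an action key, something like:

python run.py action=query

(If the key is not yet present in the config, hydra requires the +action=query form.)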