diff --git a/flagscale/auto_tuner/prune/history.py b/flagscale/auto_tuner/prune/history.py index 8c58a33e1..886897e23 100644 --- a/flagscale/auto_tuner/prune/history.py +++ b/flagscale/auto_tuner/prune/history.py @@ -30,7 +30,7 @@ def prune_by_micro_batch_size(config, strategy, history=[]): logger.info( f"The strategy {strategy} has been pruned by micro_batch_size performance." ) - strategy["performance"] = performance + strategy["performance"] = item["performance"] strategy["max_mem"] = item["max_mem"] strategy["pruned"] = True return True diff --git a/flagscale/launcher/runner.py b/flagscale/launcher/runner.py index 9dc9db8d7..f54a24efe 100644 --- a/flagscale/launcher/runner.py +++ b/flagscale/launcher/runner.py @@ -87,7 +87,7 @@ def get_host_name_or_ip(): return IP -def run_local_command(cmd, dryrun=False): +def run_local_command(cmd, dryrun=False, query=False): logger.info(f"Run the local command: {cmd}") if dryrun: return @@ -510,11 +510,6 @@ def _run_each( cmd = shlex.join(export_cmd + runner_cmd + [self.user_script] + self.user_args) - if with_test: - exp_dir = self.config.experiment.exp_dir - test_cmd = f";python tests/functional_tests/check_result.py {exp_dir};rm -r {exp_dir}" - cmd = cmd + test_cmd - host_run_script_file = _generate_run_script( self.config, host, node_rank, cmd, background=True, with_test=with_test ) @@ -693,7 +688,7 @@ def _query_each(self, host, node_rank): logger.error(f"Failed to query job status on {host}: {e}") else: try: - result = run_local_command(f"bash {host_query_script_file}") + result = run_local_command(f"bash {host_query_script_file}", query=True) except Exception as e: logger.error(f"Failed to query job status on {host}: {e}") result = result.stdout.rstrip() if result else "" @@ -795,11 +790,6 @@ def _run_each( cmd = shlex.join(export_cmd + runner_cmd + [self.user_script] + self.user_args) - if with_test: - exp_dir = self.config.experiment.exp_dir - test_cmd = f";python tests/functional_tests/check_result.py {exp_dir};rm -r {exp_dir}" - cmd = cmd + test_cmd - host_run_script_file = _generate_run_script( self.config, host, node_rank, cmd, background=False, with_test=with_test )