From e0c9afbe1393b9880f26b5108bf2d7dd7ce41ef6 Mon Sep 17 00:00:00 2001 From: Julian Klug Date: Thu, 1 Aug 2024 21:43:15 +0200 Subject: [PATCH] use storage_host in all subprocesses to access other nodes --- .../cluster/cluster_subprocess.py | 7 ++++--- .../cluster/master_launcher.py | 3 ++- .../cluster/subprocess.sbatch | 2 +- .../gridsearch_transformer.py | 7 ++++--- 4 files changed, 11 insertions(+), 8 deletions(-) diff --git a/prediction/short_term_outcome_prediction/cluster/cluster_subprocess.py b/prediction/short_term_outcome_prediction/cluster/cluster_subprocess.py index 3604b66..4c4a54a 100644 --- a/prediction/short_term_outcome_prediction/cluster/cluster_subprocess.py +++ b/prediction/short_term_outcome_prediction/cluster/cluster_subprocess.py @@ -11,14 +11,14 @@ def subprocess_cluster_gridsearch(data_splits_path:str, output_folder:str, trial_name:str, gridsearch_config_path: dict, use_gpu:bool=True, - storage_pwd:str=None, storage_port:int=None): + storage_pwd:str=None, storage_port:int=None, storage_host:str='localhost'): # load config with open(gridsearch_config_path, 'r') as f: gridsearch_config = json.load(f) if storage_pwd is not None and storage_port is not None: storage = optuna.storages.JournalStorage(optuna.storages.JournalRedisStorage( - url=f'redis://default:{storage_pwd}@localhost:{storage_port}/opsum' + url=f'redis://default:{storage_pwd}@{storage_host}:{storage_port}/opsum' )) else: storage = None @@ -43,10 +43,11 @@ def subprocess_cluster_gridsearch(data_splits_path:str, output_folder:str, trial parser.add_argument('-g', '--use_gpu', type=int, required=False, default=1) parser.add_argument('-spwd', '--storage_pwd', type=str, required=False, default=None) parser.add_argument('-sport', '--storage_port', type=int, required=False, default=None) + parser.add_argument('-shost', '--storage_host', type=str, required=False, default='localhost') args = parser.parse_args() use_gpu = args.use_gpu == 1 subprocess_cluster_gridsearch(args.data_splits_path, args.output_folder, args.trial_name, args.gridsearch_config_path, use_gpu=use_gpu, - storage_pwd=args.storage_pwd, storage_port=args.storage_port) \ No newline at end of file + storage_pwd=args.storage_pwd, storage_port=args.storage_port, storage_host=args.storage_host) \ No newline at end of file diff --git a/prediction/short_term_outcome_prediction/cluster/master_launcher.py b/prediction/short_term_outcome_prediction/cluster/master_launcher.py index 7b528b5..5c9b393 100644 --- a/prediction/short_term_outcome_prediction/cluster/master_launcher.py +++ b/prediction/short_term_outcome_prediction/cluster/master_launcher.py @@ -48,7 +48,8 @@ def launch_cluster_gridsearch(data_splits_path: str, output_folder: str, for i in range(n_subprocesses): os.system(f'sbatch --export=ALL,data_splits_path={data_splits_path},output_folder={output_folder},' f'trial_name={study_name},gridsearch_config_path={gridsearch_config_path},use_gpu={use_gpu},' - f'storage_pwd={storage_pwd},storage_port={storage_port},subprocess_py_file_path={subprocess_py_file_path} {subprocess_sbatch_file_path}') + f'storage_pwd={storage_pwd},storage_port={storage_port},storage_host={storage_host},' + f'subprocess_py_file_path={subprocess_py_file_path} {subprocess_sbatch_file_path}') if __name__ == '__main__': diff --git a/prediction/short_term_outcome_prediction/cluster/subprocess.sbatch b/prediction/short_term_outcome_prediction/cluster/subprocess.sbatch index e1f8707..35a923d 100644 --- a/prediction/short_term_outcome_prediction/cluster/subprocess.sbatch +++ b/prediction/short_term_outcome_prediction/cluster/subprocess.sbatch @@ -22,5 +22,5 @@ conda activate opsum export PYTHONPATH="${PYTHONPATH}:/home/users/k/klug/opsum" export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$CONDA_PREFIX/lib/ -srun python $subprocess_py_file_path -d $data_splits_path -o $output_folder -t $trial_name -c $gridsearch_config_path -g $use_gpu -spwd $storage_pwd -sport $storage_port +srun python $subprocess_py_file_path -d $data_splits_path -o $output_folder -t $trial_name -c $gridsearch_config_path -g $use_gpu -spwd $storage_pwd -sport $storage_port -shost $storage_host cp $OPSUM_LOGS_PATH $output_folder diff --git a/prediction/short_term_outcome_prediction/gridsearch_transformer.py b/prediction/short_term_outcome_prediction/gridsearch_transformer.py index de479ed..f77bb7c 100644 --- a/prediction/short_term_outcome_prediction/gridsearch_transformer.py +++ b/prediction/short_term_outcome_prediction/gridsearch_transformer.py @@ -38,7 +38,7 @@ } def launch_gridsearch(data_splits_path:str, output_folder:str, gridsearch_config:dict=DEFAULT_GRIDEARCH_CONFIG, use_gpu:bool=True, - storage_pwd:str=None, storage_port:int=None): + storage_pwd:str=None, storage_port:int=None, storage_host:str='localhost'): if gridsearch_config is None: gridsearch_config = DEFAULT_GRIDEARCH_CONFIG @@ -51,7 +51,7 @@ def launch_gridsearch(data_splits_path:str, output_folder:str, gridsearch_config if storage_pwd is not None and storage_port is not None: storage = optuna.storages.JournalStorage(optuna.storages.JournalRedisStorage( - url=f'redis://default:{storage_pwd}@localhost:{storage_port}/opsum' + url=f'redis://default:{storage_pwd}@{storage_host}:{storage_port}/opsum' )) else: storage = None @@ -172,6 +172,7 @@ def get_score(trial, ds, data_splits_path, output_folder, gridsearch_config:dict parser.add_argument('-g', '--use_gpu', type=int, required=False, default=1) parser.add_argument('-spwd', '--storage_pwd', type=str, required=False, default=None) parser.add_argument('-sport', '--storage_port', type=int, required=False, default=None) + parser.add_argument('-shost', '--storage_host', type=str, required=False, default=None) args = parser.parse_args() @@ -181,4 +182,4 @@ def get_score(trial, ds, data_splits_path, output_folder, gridsearch_config:dict gridsearch_config = json.load(open(args.config)) launch_gridsearch(data_splits_path=args.data_splits_path, output_folder=args.output_folder, gridsearch_config=gridsearch_config, - use_gpu=use_gpu, storage_pwd=args.storage_pwd, storage_port=args.storage_port) + use_gpu=use_gpu, storage_pwd=args.storage_pwd, storage_port=args.storage_port, storage_host=args.storage_host)