From d39bbb9f43c81602889e61e53881d5d7db63541a Mon Sep 17 00:00:00 2001 From: Bora Uyar Date: Mon, 4 Mar 2024 16:24:22 +0100 Subject: [PATCH] add gnn_conv_type argument to allow the user to choose type of convolution from CLI; make sanity checks --- flexynesis/__main__.py | 26 ++++++++++++++++++++++---- flexynesis/main.py | 6 ++++-- 2 files changed, 26 insertions(+), 6 deletions(-) diff --git a/flexynesis/__main__.py b/flexynesis/__main__.py index 9404a14..c13c8fd 100644 --- a/flexynesis/__main__.py +++ b/flexynesis/__main__.py @@ -14,7 +14,10 @@ def main(): formatter_class=argparse.ArgumentDefaultsHelpFormatter) parser.add_argument("--data_path", help="(Required) Path to the folder with train/test data files", type=str, required = True) - parser.add_argument("--model_class", help="(Required) The kind of model class to instantiate", type=str, choices=["DirectPred", "DirectPredGCNN", "supervised_vae", "MultiTripletNetwork"], required = True) + parser.add_argument("--model_class", help="(Required) The kind of model class to instantiate", type=str, + choices=["DirectPred", "DirectPredGCNN", "supervised_vae", "MultiTripletNetwork"], required = True) + parser.add_argument("--gnn_conv_type", help="If model_class is set to DirectPredGCNN, choose which graph convolution type to use", type=str, + choices=["GC", "GCN", "GAT", "GIN", "SAGE", "CHEB"]) parser.add_argument("--target_variables", help="(Optional if survival variables are not set to None)." "Which variables in 'clin.csv' to use for predictions, comma-separated if multiple", @@ -56,7 +59,7 @@ def main(): # 2. 
Check for required variables for model classes if args.model_class != "supervised_vae": if not any([args.target_variables, args.surv_event_var, args.batch_variables]): - parser.error(''.join(["When selecting a model other than 'supervised_vae'," + parser.error(''.join(["When selecting a model other than 'supervised_vae',", "you must provide at least one of --target_variables, ", "survival variables (--surv_event_var and --surv_time_var)", "or --batch_variables."])) @@ -82,7 +85,21 @@ def main(): else: device_type = 'cpu' torch.set_num_threads(args.threads) - + + gnn_conv_type = None # 5. check GNN arguments; stays None unless model_class is DirectPredGCNN, so the tuner call below never sees an unbound name + if args.model_class == 'DirectPredGCNN': + if not args.gnn_conv_type: + warning_message = "\n".join([ + "\n\n!!! When running DirectPredGCNN, a convolution type can be set", + "with the --gnn_conv_type flag. See `flexynesis -h` for full set of options.", + "Falling back on the default convolution type: GC !!!\n\n" + ]) + warnings.warn(warning_message) + time.sleep(3) #wait a bit to capture user's attention to the warning + gnn_conv_type = 'GC' + else: + gnn_conv_type = args.gnn_conv_type + # Validate paths if not os.path.exists(args.data_path): raise FileNotFoundError(f"Input --data_path doesn't exist at:", {args.data_path}) @@ -144,7 +161,8 @@ class AvailableModels(NamedTuple): n_iter=int(args.hpo_iter), use_loss_weighting = args.use_loss_weighting == 'True', early_stop_patience = int(args.early_stop_patience), - device_type = device_type) + device_type = device_type, + gnn_conv_type = gnn_conv_type) # do a hyperparameter search training multiple models and get the best_configuration model, best_params = tuner.perform_tuning() diff --git a/flexynesis/main.py b/flexynesis/main.py index 1217753..ac4f746 100644 --- a/flexynesis/main.py +++ b/flexynesis/main.py @@ -50,7 +50,7 @@ def __init__(self, dataset, model_class, config_name, target_variables, batch_variables = None, surv_event_var = None, surv_time_var = None, n_iter = 10, config_path = None, plot_losses = False, val_size = 
0.2, use_loss_weighting = True, early_stop_patience = -1, - device_type = None): + device_type = None, gnn_conv_type = None): self.dataset = dataset self.model_class = model_class self.target_variables = target_variables @@ -71,6 +71,7 @@ def __init__(self, dataset, model_class, config_name, target_variables, progress_bar_finished='red')) self.early_stop_patience = early_stop_patience self.use_loss_weighting = use_loss_weighting + self.gnn_conv_type = gnn_conv_type # If config_path is provided, use it if config_path: @@ -93,7 +94,8 @@ def objective(self, params, current_step, total_steps): surv_time_var = self.surv_time_var, val_size = self.val_size, use_loss_weighting = self.use_loss_weighting, - device_type = self.device_type) + device_type = self.device_type, + gnn_conv_type = self.gnn_conv_type) print(params) mycallbacks = [self.progress_bar]