Skip to content

Commit

Permalink
add gnn_conv_type argument to allow the user to choose type of convol…
Browse files Browse the repository at this point in the history
…ution from CLI; make sanity checks
  • Loading branch information
borauyar committed Mar 4, 2024
1 parent 91f90d5 commit d39bbb9
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 6 deletions.
26 changes: 22 additions & 4 deletions flexynesis/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,10 @@ def main():
formatter_class=argparse.ArgumentDefaultsHelpFormatter)

parser.add_argument("--data_path", help="(Required) Path to the folder with train/test data files", type=str, required = True)
parser.add_argument("--model_class", help="(Required) The kind of model class to instantiate", type=str, choices=["DirectPred", "DirectPredGCNN", "supervised_vae", "MultiTripletNetwork"], required = True)
parser.add_argument("--model_class", help="(Required) The kind of model class to instantiate", type=str,
choices=["DirectPred", "DirectPredGCNN", "supervised_vae", "MultiTripletNetwork"], required = True)
parser.add_argument("--gnn_conv_type", help="If model_class is set to DirectPredGCNN, choose which graph convolution type to use", type=str,
choices=["GC", "GCN", "GAT", "GIN", "SAGE", "CHEB"])
parser.add_argument("--target_variables",
help="(Optional if survival variables are not set to None)."
"Which variables in 'clin.csv' to use for predictions, comma-separated if multiple",
Expand Down Expand Up @@ -56,7 +59,7 @@ def main():
# 2. Check for required variables for model classes
if args.model_class != "supervised_vae":
if not any([args.target_variables, args.surv_event_var, args.batch_variables]):
parser.error(''.join(["When selecting a model other than 'supervised_vae',"
parser.error(''.join(["When selecting a model other than 'supervised_vae',",
"you must provide at least one of --target_variables, ",
"survival variables (--surv_event_var and --surv_time_var)",
"or --batch_variables."]))
Expand All @@ -82,7 +85,21 @@ def main():
else:
device_type = 'cpu'
torch.set_num_threads(args.threads)


# 5. check GNN arguments
if args.model_class == 'DirectPredGCNN':
if not args.gnn_conv_type:
warning_message = "\n".join([
"\n\n!!! When running DirectPredGCNN, a convolution type can be set",
"with the --gnn_conv_type flag. See `flexynesis -h` for full set of options.",
"Falling back on the default convolution type: GC !!!\n\n"
])
warnings.warn(warning_message)
time.sleep(3) #wait a bit to capture user's attention to the warning
gnn_conv_type = 'GC'
else:
gnn_conv_type = args.gnn_conv_type

# Validate paths
if not os.path.exists(args.data_path):
raise FileNotFoundError(f"Input --data_path doesn't exist at:", {args.data_path})
Expand Down Expand Up @@ -144,7 +161,8 @@ class AvailableModels(NamedTuple):
n_iter=int(args.hpo_iter),
use_loss_weighting = args.use_loss_weighting == 'True',
early_stop_patience = int(args.early_stop_patience),
device_type = device_type)
device_type = device_type,
gnn_conv_type = gnn_conv_type)

# do a hyperparameter search training multiple models and get the best_configuration
model, best_params = tuner.perform_tuning()
Expand Down
6 changes: 4 additions & 2 deletions flexynesis/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def __init__(self, dataset, model_class, config_name, target_variables,
batch_variables = None, surv_event_var = None, surv_time_var = None,
n_iter = 10, config_path = None, plot_losses = False,
val_size = 0.2, use_loss_weighting = True, early_stop_patience = -1,
device_type = None):
device_type = None, gnn_conv_type = None):
self.dataset = dataset
self.model_class = model_class
self.target_variables = target_variables
Expand All @@ -71,6 +71,7 @@ def __init__(self, dataset, model_class, config_name, target_variables,
progress_bar_finished='red'))
self.early_stop_patience = early_stop_patience
self.use_loss_weighting = use_loss_weighting
self.gnn_conv_type = gnn_conv_type

# If config_path is provided, use it
if config_path:
Expand All @@ -93,7 +94,8 @@ def objective(self, params, current_step, total_steps):
surv_time_var = self.surv_time_var,
val_size = self.val_size,
use_loss_weighting = self.use_loss_weighting,
device_type = self.device_type)
device_type = self.device_type,
gnn_conv_type = self.gnn_conv_type)
print(params)

mycallbacks = [self.progress_bar]
Expand Down

0 comments on commit d39bbb9

Please sign in to comment.