Commit
Fixes #75; cross-validation is optional; default training uses single-split train/val
borauyar committed May 13, 2024
1 parent 13a42e5 commit 1e1fc7c
Showing 2 changed files with 39 additions and 16 deletions.
3 changes: 3 additions & 0 deletions flexynesis/__main__.py
@@ -53,6 +53,8 @@ def main():
parser.add_argument("--prefix", help="Job prefix to use for output files", type=str, default = 'job')
parser.add_argument("--log_transform", help="whether to apply log-transformation to input data matrices", type=str, choices=['True', 'False'], default = 'False')
parser.add_argument("--early_stop_patience", help="How many epochs to wait when no improvements in validation loss is observed (default: 10; set to -1 to disable early stopping)", type=int, default = 10)
parser.add_argument("--use_cv", action="store_true",
help="(Optional) If set, the a 5-fold cross-validation training will be done. Otherwise, a single trainign on 80% of the dataset is done.")
parser.add_argument("--use_loss_weighting", help="whether to apply loss-balancing using uncertainty weights method", type=str, choices=['True', 'False'], default = 'True')
parser.add_argument("--evaluate_baseline_performance", help="whether to run Random Forest + SVMs to see the performance of off-the-shelf tools on the same dataset", type=str, choices=['True', 'False'], default = 'True')
parser.add_argument("--threads", help="(Optional) How many threads to use when using CPU (default: 4)", type=int, default = 4)
@@ -208,6 +210,7 @@ class AvailableModels(NamedTuple):
config_path = args.config_path,
n_iter=int(args.hpo_iter),
use_loss_weighting = args.use_loss_weighting == 'True',
use_cv = args.use_cv,
early_stop_patience = int(args.early_stop_patience),
device_type = device_type,
gnn_conv_type = gnn_conv_type,
52 changes: 36 additions & 16 deletions flexynesis/main.py
@@ -50,7 +50,8 @@ class HyperparameterTuning:
def __init__(self, dataset, model_class, config_name, target_variables,
batch_variables = None, surv_event_var = None, surv_time_var = None,
n_iter = 10, config_path = None, plot_losses = False,
cv_splits = 5, use_loss_weighting = True, early_stop_patience = -1,
val_size = 0.2, use_cv = False, cv_splits = 5,
use_loss_weighting = True, early_stop_patience = -1,
device_type = None, gnn_conv_type = None,
input_layers = None, output_layers = None):
self.dataset = dataset # dataset for model initiation
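A hedged instantiation sketch showing the new knobs and their defaults as declared above; the dataset, model class, and config/target names are placeholders, not part of this commit:

# `my_dataset`, `MyModel`, "MyConfig", and "y" below are hypothetical
tuner = HyperparameterTuning(
    my_dataset, MyModel, config_name="MyConfig",
    target_variables=["y"],
    use_cv=False,   # default: a single train/validation split
    val_size=0.2,   # fraction held out for validation when use_cv is False
    cv_splits=5,    # number of folds, used only when use_cv is True
)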
@@ -66,6 +67,8 @@ def __init__(self, dataset, model_class, config_name, target_variables,
self.config_name = config_name
self.n_iter = n_iter
self.plot_losses = plot_losses # Whether to show live loss plots (useful in interactive mode)
self.val_size = val_size
self.use_cv = use_cv
self.n_splits = cv_splits
self.progress_bar = RichProgressBar(
theme = RichProgressBarTheme(
@@ -150,24 +153,35 @@ def objective(self, params, current_step, total_steps, full_train = False):

if full_train:
# Train on the full dataset
full_loader = DataLoader(self.loader_dataset, batch_size=int(params['batch_size']), shuffle=True)
full_loader = DataLoader(self.loader_dataset, batch_size=int(params['batch_size']),
shuffle=True, pin_memory=True, drop_last=True)
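# (pin_memory=True speeds up host-to-GPU transfers; drop_last=True drops a
# trailing partial batch, which can otherwise destabilize batch statistics)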
model = self.model_class(**model_args)
trainer, _ = self.setup_trainer(params, current_step, total_steps, full_train = True)
trainer.fit(model, train_dataloaders=full_loader)
return model # Return the trained model

else:
# Perform k-fold cross-validation
validation_losses = []
kf = KFold(n_splits=self.n_splits, shuffle=True)
i = 1
epochs = [] # number of epochs per fold
for train_index, val_index in kf.split(self.loader_dataset):
print(f"[INFO] training cross-validation fold {i}")
epochs = []

if self.use_cv: # if the user asks for cross-validation
kf = KFold(n_splits=self.n_splits, shuffle=True)
split_iterator = kf.split(self.loader_dataset)
else: # otherwise do a single train/validation split
# compute validation/training sizes from the held-out fraction (val_size)
num_val = int(len(self.loader_dataset) * self.val_size)
num_train = len(self.loader_dataset) - num_val
train_subset, val_subset = random_split(self.loader_dataset, [num_train, num_val])
# create single split format similar to KFold
split_iterator = [(list(train_subset.indices), list(val_subset.indices))]
i = 1
model = None # will hold the trained model when not using cross-validation
for train_index, val_index in split_iterator:
print(f"[INFO] training {'cross-validation fold' if self.use_cv else 'single train/validation split'} {i}")
train_subset = torch.utils.data.Subset(self.loader_dataset, train_index)
val_subset = torch.utils.data.Subset(self.loader_dataset, val_index)
train_loader = DataLoader(train_subset, batch_size=int(params['batch_size']), shuffle=True)
val_loader = DataLoader(val_subset, batch_size=int(params['batch_size']))
train_loader = DataLoader(train_subset, batch_size=int(params['batch_size']), pin_memory=True, shuffle=True, drop_last=True)
val_loader = DataLoader(val_subset, batch_size=int(params['batch_size']), pin_memory=True, shuffle=False)

model = self.model_class(**model_args)
trainer, early_stop_callback = self.setup_trainer(params, current_step, total_steps)
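As a minimal standalone sketch of the two split strategies constructed above, with a toy 10-sample dataset standing in for self.loader_dataset:

from sklearn.model_selection import KFold
from torch.utils.data import random_split

data = list(range(10))  # toy stand-in for self.loader_dataset

# use_cv=True path: five (train_index, val_index) pairs
cv_splits = list(KFold(n_splits=5, shuffle=True).split(data))

# use_cv=False path: one 80/20 split wrapped in the same (train, val) index shape
train_subset, val_subset = random_split(data, [8, 2])
single_split = [(list(train_subset.indices), list(val_subset.indices))]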
@@ -180,42 +194,48 @@ def objective(self, params, current_step, total_steps, full_train = False):
val_loss = validation_result[0]['val_loss']
validation_losses.append(val_loss)
i += 1
# when not cross-validating, the loop ran once and `model` already
# holds the trained single-split model

# average the validation loss across folds (a single value when use_cv is False)
avg_val_loss = np.mean(validation_losses)
avg_epochs = int(np.mean(epochs))
return avg_val_loss, avg_epochs
return avg_val_loss, avg_epochs, model

def perform_tuning(self):
opt = Optimizer(dimensions=self.space, n_initial_points=10, acq_func="gp_hedge", acq_optimizer="auto")

best_loss = np.inf
best_params = None
best_epochs = 0
best_model = None

with tqdm(total=self.n_iter, desc='Tuning Progress') as pbar:
for i in range(self.n_iter):
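# shim: np.int was removed in newer NumPy releases, but older
# scikit-optimize versions still reference it internally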
np.int = int
suggested_params_list = opt.ask()
suggested_params_dict = {param.name: value for param, value in zip(self.space, suggested_params_list)}
loss, avg_epochs = self.objective(suggested_params_dict, current_step=i+1, total_steps=self.n_iter)
print(f"[INFO] average 5-fold cross-validation loss {loss} for params: {suggested_params_dict}")
loss, avg_epochs, model = self.objective(suggested_params_dict, current_step=i+1, total_steps=self.n_iter)
if self.use_cv:
print(f"[INFO] average 5-fold cross-validation loss {loss} for params: {suggested_params_dict}")
opt.tell(suggested_params_list, loss)

if loss < best_loss:
best_loss = loss
best_params = suggested_params_list
best_epochs = avg_epochs
best_model = model

# Print result of each iteration
pbar.set_postfix({'Iteration': i+1, 'Best Loss': best_loss})
pbar.update(1)
# convert best params to dict
best_params_dict = {param.name: value for param, value in zip(self.space, best_params)}
best_params_dict['epochs'] = best_epochs # use the epochs from the best iteration, not the last one
# build a final model based on best params
print(f"[INFO] Building a final model using best params: {best_params_dict}")
best_model = self.objective(best_params_dict, current_step=0, total_steps=1, full_train=True)
if self.use_cv:
# when cross-validating, retrain a final model on the full dataset using the best params
print(f"[INFO] Building a final model using best params: {best_params_dict}")
best_model = self.objective(best_params_dict, current_step=0, total_steps=1, full_train=True)
return best_model, best_params_dict

def init_early_stopping(self):
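For reference, a minimal self-contained sketch of the ask/tell pattern used in perform_tuning above, assuming scikit-optimize is installed; the one-knob search space and the quadratic objective are stand-ins for the real hyperparameter space and model-training objective:

from skopt import Optimizer
from skopt.space import Real

space = [Real(1e-5, 1e-1, name="lr")]  # hypothetical single-knob search space
opt = Optimizer(dimensions=space, n_initial_points=10,
                acq_func="gp_hedge", acq_optimizer="auto")

best_loss, best_params = float("inf"), None
for _ in range(25):
    suggested = opt.ask()              # next hyperparameters to try
    loss = (suggested[0] - 0.01) ** 2  # stand-in for the validation loss
    opt.tell(suggested, loss)          # report the result back to the optimizer
    if loss < best_loss:
        best_loss, best_params = loss, suggested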
