From 9669f9085d2586be81e4d4dc335b5f930ef7eab0 Mon Sep 17 00:00:00 2001 From: Bora Uyar Date: Wed, 1 May 2024 19:18:59 +0200 Subject: [PATCH 1/6] update HPO class to do cross-validation to pick best model; build a final model using best params --- flexynesis/main.py | 141 +++++++++++++++++++++++++++++---------------- 1 file changed, 90 insertions(+), 51 deletions(-) diff --git a/flexynesis/main.py b/flexynesis/main.py index a9c993c..0e55e6e 100644 --- a/flexynesis/main.py +++ b/flexynesis/main.py @@ -37,7 +37,7 @@ class HyperparameterTuning: config_name: Name of the configuration for tuning parameters. n_iter: Number of iterations for the tuning process. plot_losses: Boolean flag to plot losses during training. - val_size: Validation set size as a fraction of the dataset. + cv_splits: Number of cross-validation folds. use_loss_weighting: Flag to use loss weighting during training. early_stop_patience: Number of epochs to wait for improvement before stopping. device_type: Str (cpu, gpu) @@ -50,7 +50,7 @@ class HyperparameterTuning: def __init__(self, dataset, model_class, config_name, target_variables, batch_variables = None, surv_event_var = None, surv_time_var = None, n_iter = 10, config_path = None, plot_losses = False, - val_size = 0.2, use_loss_weighting = True, early_stop_patience = -1, + cv_splits = 5, use_loss_weighting = True, early_stop_patience = -1, device_type = None, gnn_conv_type = None, input_layers = None, output_layers = None): self.dataset = dataset @@ -65,7 +65,7 @@ def __init__(self, dataset, model_class, config_name, target_variables, self.config_name = config_name self.n_iter = n_iter self.plot_losses = plot_losses # Whether to show live loss plots (useful in interactive mode) - self.val_size = val_size + self.n_splits = cv_splits self.progress_bar = RichProgressBar( theme = RichProgressBarTheme( progress_bar = 'green1', @@ -93,7 +93,7 @@ def __init__(self, dataset, model_class, config_name, target_variables, raise ValueError(f"'{self.config_name}' not found in the default config.") def get_batch_space(self, min_size = 16, max_size = 256): - m = int(np.log2(len(self.dataset) * (1 - self.val_size))) + m = int(np.log2(len(self.dataset) * 0.8)) st = int(np.log2(min_size)) end = int(np.log2(max_size)) if m < end: @@ -101,78 +101,117 @@ def get_batch_space(self, min_size = 16, max_size = 256): s = Categorical([np.power(2, x) for x in range(st, end+1)], name = 'batch_size') return s - def objective(self, params, current_step, total_steps): - - # args common to all model classes - model_args = {"config": params, "dataset": self.dataset, "target_variables": self.target_variables, - "batch_variables": self.batch_variables, "surv_event_var": self.surv_event_var, - "surv_time_var": self.surv_time_var, "val_size": self.val_size, - "use_loss_weighting": self.use_loss_weighting, "device_type": self.device_type} - if self.model_class.__name__ == 'DirectPredGCNN': - model_args["gnn_conv_type"] = self.gnn_conv_type - if self.model_class.__name__ == 'CrossModalPred': - model_args["input_layers"] = self.input_layers - model_args["output_layers"] = self.output_layers - - model = self.model_class(**model_args) - print(params) - + def setup_trainer(self, params, current_step, total_steps, full_train = False): + # Configure callbacks and trainer for the current fold mycallbacks = [self.progress_bar] - # for interactive usage, only show loss plots if self.plot_losses: - mycallbacks = [LiveLossPlot(hyperparams=params, current_step=current_step, total_steps=total_steps)] - - if 
self.early_stop_patience > 0: - mycallbacks.append(self.init_early_stopping()) - - trainer = pl.Trainer(max_epochs=int(params['epochs']), log_every_n_steps=5, - callbacks = mycallbacks, default_root_dir="./", logger=False, - enable_checkpointing=False, - devices=1, accelerator=self.device_type) - - # Create a new Trainer instance for validation, ensuring single-device processing - validation_trainer = pl.Trainer( - logger=False, + mycallbacks.append(LiveLossPlot(hyperparams=params, current_step=current_step, total_steps=total_steps)) + # when training on a full dataset; no cross-validation or no validation splits; + # we don't do early stopping + early_stop_callback = None + if self.early_stop_patience > 0 and full_train == False: + early_stop_callback = self.init_early_stopping() + mycallbacks.append(early_stop_callback) + + trainer = pl.Trainer( + max_epochs=int(params['epochs']), + log_every_n_steps=5, + callbacks=mycallbacks, + default_root_dir="./", + logger=False, enable_checkpointing=False, - devices=1, # make sure to a single device for validation + devices=1, accelerator=self.device_type ) + return trainer, early_stop_callback + + def objective(self, params, current_step, total_steps, full_train = False): + # Unpack or construct specific model arguments + model_args = { + "config": params, + "dataset": self.dataset, + "target_variables": self.target_variables, + "batch_variables": self.batch_variables, + "surv_event_var": self.surv_event_var, + "surv_time_var": self.surv_time_var, + "use_loss_weighting": self.use_loss_weighting, + "device_type": self.device_type, + } - try: - # Train the model - trainer.fit(model) - # Validate the model - val_loss = validation_trainer.validate(model)[0]['val_loss'] - except ValueError as e: - print(str(e)) - val_loss = float('inf') # or some other value indicating failure - return val_loss, model + if self.model_class.__name__ == 'DirectPredGCNN': + model_args['gnn_conv_type'] = self.gnn_conv_type + if self.model_class.__name__ == 'CrossModalPred': + model_args['input_layers'] = self.input_layers + model_args['output_layers'] = self.output_layers + + if full_train: + # Train on the full dataset + full_loader = DataLoader(self.dataset, batch_size=int(params['batch_size']), shuffle=True) + model = self.model_class(**model_args) + trainer, _ = self.setup_trainer(params, current_step, total_steps, full_train = True) + trainer.fit(model, train_dataloaders=full_loader) + return model # Return the trained model + + else: + # Perform k-fold cross-validation + validation_losses = [] + kf = KFold(n_splits=self.n_splits, shuffle=True) + i = 1 + epochs = [] # number of epochs per fold + for train_index, val_index in kf.split(self.dataset): + print(f"[INFO] training cross-validation fold {i}") + train_subset = torch.utils.data.Subset(self.dataset, train_index) + val_subset = torch.utils.data.Subset(self.dataset, val_index) + train_loader = DataLoader(train_subset, batch_size=int(params['batch_size']), shuffle=True) + val_loader = DataLoader(val_subset, batch_size=int(params['batch_size'])) + + model = self.model_class(**model_args) + trainer, early_stop_callback = self.setup_trainer(params, current_step, total_steps) + trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader) + if early_stop_callback.stopped_epoch: + epochs.append(early_stop_callback.stopped_epoch) + else: + epochs.append(int(params['epochs'])) + validation_result = trainer.validate(model, dataloaders=val_loader) + val_loss = validation_result[0]['val_loss'] + 
validation_losses.append(val_loss) + i += 1 + + # Calculate average validation loss across all folds + avg_val_loss = np.mean(validation_losses) + avg_epochs = int(np.mean(epochs)) + return avg_val_loss, avg_epochs def perform_tuning(self): opt = Optimizer(dimensions=self.space, n_initial_points=10, acq_func="gp_hedge", acq_optimizer="auto") best_loss = np.inf best_params = None - best_model = None - + best_epochs = 0 + with tqdm(total=self.n_iter, desc='Tuning Progress') as pbar: for i in range(self.n_iter): np.int = int suggested_params_list = opt.ask() suggested_params_dict = {param.name: value for param, value in zip(self.space, suggested_params_list)} - loss, model = self.objective(suggested_params_dict, current_step=i+1, total_steps=self.n_iter) + loss, avg_epochs = self.objective(suggested_params_dict, current_step=i+1, total_steps=self.n_iter) + print(f"[INFO] average 5-fold cross-validation loss {loss} for params: {suggested_params_dict}") opt.tell(suggested_params_list, loss) if loss < best_loss: best_loss = loss best_params = suggested_params_list - best_model = model + best_epochs = avg_epochs # Print result of each iteration pbar.set_postfix({'Iteration': i+1, 'Best Loss': best_loss}) pbar.update(1) # convert best params to dict best_params_dict = {param.name: value for param, value in zip(self.space, best_params)} + best_params_dict['epochs'] = avg_epochs + # build a final model based on best params + print(f"[INFO] Building a final model using best params: {best_params_dict}") + best_model = self.objective(best_params_dict, current_step=0, total_steps=1, full_train=True) return best_model, best_params_dict def init_early_stopping(self): @@ -180,7 +219,7 @@ def init_early_stopping(self): return EarlyStopping( monitor='val_loss', patience=self.early_stop_patience, - verbose=True, + verbose=False, mode='min' ) @@ -296,9 +335,9 @@ def run_experiments(self): ) trainer = pl.Trainer(max_epochs=self.max_epoch, devices=1, accelerator='auto', logger=False, enable_checkpointing=False, enable_progress_bar = False, enable_model_summary=False, callbacks=[early_stopping]) - trainer.fit(self) + trainer.fit(self, train_dataloaders=self.train_dataloader(), val_dataloaders=self.val_dataloader()) stopped_epoch = early_stopping.stopped_epoch - val_loss = trainer.validate(self.model, verbose = False) + val_loss = trainer.validate(self, dataloaders = self.val_dataloader(), verbose = False) fold_losses.append(val_loss[0]['val_loss']) # Adjust based on your validation output format epochs.append(stopped_epoch) #print(f"[INFO] Finetuning ... 
training fold: {fold}, learning rate: {lr}, val_loss: {val_loss}, freeze {config}") From 30bf471f9ebb4c9fff83b4e261eea9df085e0013 Mon Sep 17 00:00:00 2001 From: Bora Uyar Date: Wed, 1 May 2024 19:19:16 +0200 Subject: [PATCH 2/6] remove data loaders from class definition --- flexynesis/models/direct_pred.py | 24 +----------------------- 1 file changed, 1 insertion(+), 23 deletions(-) diff --git a/flexynesis/models/direct_pred.py b/flexynesis/models/direct_pred.py index 845418d..5586d95 100644 --- a/flexynesis/models/direct_pred.py +++ b/flexynesis/models/direct_pred.py @@ -16,7 +16,7 @@ class DirectPred(pl.LightningModule): def __init__(self, config, dataset, target_variables, batch_variables = None, - surv_event_var = None, surv_time_var = None, val_size = 0.2, use_loss_weighting = True, + surv_event_var = None, surv_time_var = None, use_loss_weighting = True, device_type = None): super(DirectPred, self).__init__() self.config = config @@ -29,14 +29,10 @@ def __init__(self, config, dataset, target_variables, batch_variables = None, self.target_variables = self.target_variables + [self.surv_event_var] self.batch_variables = batch_variables self.variables = self.target_variables + batch_variables if batch_variables else self.target_variables - self.val_size = val_size self.feature_importances = {} self.use_loss_weighting = use_loss_weighting self.device_type = device_type - # define data loaders - self.prepare_data_loaders(dataset) - if self.use_loss_weighting: # Initialize log variance parameters for uncertainty weighting self.log_vars = nn.ParameterDict() @@ -196,24 +192,6 @@ def validation_step(self, val_batch, batch_idx, log = True): self.log_dict(losses, on_step=False, on_epoch=True, prog_bar=True) return total_loss - def prepare_data_loaders(self, dataset): - # Split the dataset - train_size = int(len(dataset) * (1 - self.val_size)) - val_size = len(dataset) - train_size - dat_train, dat_val = random_split(dataset, [train_size, val_size], generator=torch.Generator().manual_seed(42)) - - # Create data loaders - self.train_loader = DataLoader(dat_train, batch_size=int(self.config['batch_size']), - num_workers=0, pin_memory=True, shuffle=True, drop_last=True) - self.val_loader = DataLoader(dat_val, batch_size=int(self.config['batch_size']), - num_workers=0, pin_memory=True, shuffle=False) - - def train_dataloader(self): - return self.train_loader - - def val_dataloader(self): - return self.val_loader - def predict(self, dataset): """ Evaluate the DirectPred model on a given dataset. 
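Patches 1/6 and 2/6 together move train/validation splitting out of the model classes and into the tuner: HyperparameterTuning.objective now derives fold indices with sklearn.model_selection.KFold, wraps them in torch.utils.data.Subset, and hands per-fold DataLoaders to trainer.fit, so models such as DirectPred no longer need their own prepare_data_loaders. A minimal sketch of that pattern in isolation follows; the toy tensor dataset, the small nn.Sequential model, and the short training loop are placeholders standing in for the flexynesis dataset and Lightning modules, not code from the repository.

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Subset, TensorDataset
from sklearn.model_selection import KFold

# Toy stand-in for the multi-omics dataset object (placeholder, not the flexynesis class).
X = torch.randn(120, 20)
y = torch.randint(0, 2, (120,))
dataset = TensorDataset(X, y)

kf = KFold(n_splits=5, shuffle=True, random_state=42)
fold_losses = []
for fold, (train_idx, val_idx) in enumerate(kf.split(np.arange(len(dataset))), start=1):
    # Per-fold loaders built from index subsets, mirroring the new objective() loop.
    train_loader = DataLoader(Subset(dataset, train_idx), batch_size=16, shuffle=True)
    val_loader = DataLoader(Subset(dataset, val_idx), batch_size=16)

    # A fresh model is instantiated for every fold.
    model = nn.Sequential(nn.Linear(20, 8), nn.ReLU(), nn.Linear(8, 2))
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    loss_fn = nn.CrossEntropyLoss()

    for epoch in range(5):  # short training loop in place of trainer.fit()
        for xb, yb in train_loader:
            optimizer.zero_grad()
            loss_fn(model(xb), yb).backward()
            optimizer.step()

    with torch.no_grad():  # mean validation loss for this fold
        val_loss = sum(loss_fn(model(xb), yb).item() for xb, yb in val_loader) / len(val_loader)
    fold_losses.append(val_loss)
    print(f"fold {fold}: val_loss={val_loss:.4f}")

# The candidate configuration is scored by the average loss across folds.
print(f"mean cross-validation loss: {np.mean(fold_losses):.4f}")
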
From 684e3a6d7e02bef4a6478d88f1e044816a8a383f Mon Sep 17 00:00:00 2001 From: Bora Uyar Date: Wed, 1 May 2024 19:51:29 +0200 Subject: [PATCH 3/6] remove dataloaders from class --- flexynesis/models/supervised_vae.py | 18 +----------------- 1 file changed, 1 insertion(+), 17 deletions(-) diff --git a/flexynesis/models/supervised_vae.py b/flexynesis/models/supervised_vae.py index 5051ea6..128fe6e 100644 --- a/flexynesis/models/supervised_vae.py +++ b/flexynesis/models/supervised_vae.py @@ -43,7 +43,7 @@ class supervised_vae(pl.LightningModule): """ def __init__(self, config, dataset, target_variables, batch_variables = None, - surv_event_var = None, surv_time_var = None, val_size = 0.2, use_loss_weighting = True, + surv_event_var = None, surv_time_var = None, use_loss_weighting = True, device_type = None): super(supervised_vae, self).__init__() self.config = config @@ -57,9 +57,6 @@ def __init__(self, config, dataset, target_variables, batch_variables = None, self.target_variables = self.target_variables + [self.surv_event_var] self.batch_variables = batch_variables self.variables = self.target_variables + batch_variables if batch_variables else self.target_variables - self.val_size = val_size - - self.dat_train, self.dat_val = self.prepare_data() self.feature_importances = {} # sometimes the model may have exploding/vanishing gradients leading to NaN values @@ -279,19 +276,6 @@ def validation_step(self, val_batch, batch_idx, log = True): self.log_dict(losses, on_step=False, on_epoch=True, prog_bar=True) return total_loss - def prepare_data(self): - lt = int(len(self.dataset)*(1-self.val_size)) - lv = len(self.dataset)-lt - dat_train, dat_val = random_split(self.dataset, [lt, lv], - generator=torch.Generator().manual_seed(42)) - return dat_train, dat_val - - def train_dataloader(self): - return DataLoader(self.dat_train, batch_size=int(self.config['batch_size']), num_workers=0, pin_memory=True, shuffle=True, drop_last=True) - - def val_dataloader(self): - return DataLoader(self.dat_val, batch_size=int(self.config['batch_size']), num_workers=0, pin_memory=True, shuffle=False) - def transform(self, dataset): """ Transform the input dataset to latent representation. From 1d56da9f6424d1b4b72d4113204d9c743394deed Mon Sep 17 00:00:00 2001 From: Bora Uyar Date: Wed, 1 May 2024 20:08:07 +0200 Subject: [PATCH 4/6] remove data loaders from class definition --- flexynesis/models/triplet_encoder.py | 25 ++----------------------- 1 file changed, 2 insertions(+), 23 deletions(-) diff --git a/flexynesis/models/triplet_encoder.py b/flexynesis/models/triplet_encoder.py index 846a461..a745497 100644 --- a/flexynesis/models/triplet_encoder.py +++ b/flexynesis/models/triplet_encoder.py @@ -22,7 +22,7 @@ class MultiTripletNetwork(pl.LightningModule): """ """ def __init__(self, config, dataset, target_variables, batch_variables = None, - surv_event_var = None, surv_time_var = None, val_size = 0.2, use_loss_weighting = True, + surv_event_var = None, surv_time_var = None, use_loss_weighting = True, device_type = None): """ Initialize the MultiTripletNetwork with the given parameters. 
@@ -42,7 +42,6 @@ def __init__(self, config, dataset, target_variables, batch_variables = None, self.target_variables = self.target_variables + [self.surv_event_var] self.batch_variables = batch_variables self.variables = self.target_variables + batch_variables if batch_variables else self.target_variables - self.val_size = val_size self.ann = dataset.ann self.variable_types = dataset.variable_types self.feature_importances = {} @@ -61,8 +60,6 @@ def __init__(self, config, dataset, target_variables, batch_variables = None, for loss_type in itertools.chain(self.variables, ['triplet_loss']): self.log_vars[loss_type] = nn.Parameter(torch.zeros(1)) - self.prepare_data_loaders(dataset) - self.layers = list(dataset.dat.keys()) self.input_dims = [len(dataset.features[self.layers[i]]) for i in range(len(self.layers))] @@ -226,25 +223,7 @@ def validation_step(self, val_batch, batch_idx, log = True): if log: self.log_dict(losses, on_step=False, on_epoch=True, prog_bar=True) return total_loss - - def prepare_data_loaders(self, dataset): - # create train/validation splits and convert TripletMultiOmicDataset format - triplet_dataset = TripletMultiOmicDataset(dataset, self.main_var) - lt = int(len(triplet_dataset)*(1-self.val_size)) - lv = len(triplet_dataset)-lt - dat_train, dat_val = random_split(triplet_dataset, [lt, lv], - generator=torch.Generator().manual_seed(42)) - self.train_loader = DataLoader(dat_train, batch_size=int(self.config['batch_size']), - num_workers=0, pin_memory=True, shuffle=True, drop_last=True) - self.val_loader = DataLoader(dat_val, batch_size=int(self.config['batch_size']), - num_workers=0, pin_memory=True, shuffle=False) - - def train_dataloader(self): - return self.train_loader - - def val_dataloader(self): - return self.val_loader - + # dataset: MultiOmicDataset def transform(self, dataset): """ From 7a10cbc23ac2f07e9af913f9c4b71c09384a7de4 Mon Sep 17 00:00:00 2001 From: Bora Uyar Date: Wed, 1 May 2024 20:08:32 +0200 Subject: [PATCH 5/6] account for data structure for data loaders --- flexynesis/main.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/flexynesis/main.py b/flexynesis/main.py index 0e55e6e..962cda7 100644 --- a/flexynesis/main.py +++ b/flexynesis/main.py @@ -53,7 +53,8 @@ def __init__(self, dataset, model_class, config_name, target_variables, cv_splits = 5, use_loss_weighting = True, early_stop_patience = -1, device_type = None, gnn_conv_type = None, input_layers = None, output_layers = None): - self.dataset = dataset + self.dataset = dataset # dataset for model initiation + self.loader_dataset = dataset # dataset for defining data loaders (this can be model specific) self.model_class = model_class self.target_variables = target_variables self.device_type = device_type @@ -77,6 +78,9 @@ def __init__(self, dataset, model_class, config_name, target_variables, self.input_layers = input_layers self.output_layers = output_layers + if self.model_class.__name__ == 'MultiTripletNetwork': + self.loader_dataset = TripletMultiOmicDataset(self.dataset, self.target_variables[0]) + # If config_path is provided, use it if config_path: external_config = self.load_and_convert_config(config_path) @@ -146,7 +150,7 @@ def objective(self, params, current_step, total_steps, full_train = False): if full_train: # Train on the full dataset - full_loader = DataLoader(self.dataset, batch_size=int(params['batch_size']), shuffle=True) + full_loader = DataLoader(self.loader_dataset, batch_size=int(params['batch_size']), shuffle=True) model = 
self.model_class(**model_args) trainer, _ = self.setup_trainer(params, current_step, total_steps, full_train = True) trainer.fit(model, train_dataloaders=full_loader) @@ -158,10 +162,10 @@ def objective(self, params, current_step, total_steps, full_train = False): kf = KFold(n_splits=self.n_splits, shuffle=True) i = 1 epochs = [] # number of epochs per fold - for train_index, val_index in kf.split(self.dataset): + for train_index, val_index in kf.split(self.loader_dataset): print(f"[INFO] training cross-validation fold {i}") - train_subset = torch.utils.data.Subset(self.dataset, train_index) - val_subset = torch.utils.data.Subset(self.dataset, val_index) + train_subset = torch.utils.data.Subset(self.loader_dataset, train_index) + val_subset = torch.utils.data.Subset(self.loader_dataset, val_index) train_loader = DataLoader(train_subset, batch_size=int(params['batch_size']), shuffle=True) val_loader = DataLoader(val_subset, batch_size=int(params['batch_size'])) From 5576b632664ad7b7b0afc89be8d0555d726c10ba Mon Sep 17 00:00:00 2001 From: Bora Uyar Date: Wed, 1 May 2024 20:15:58 +0200 Subject: [PATCH 6/6] adapt crossmodal pred to finetuning and crossvalidation features --- flexynesis/models/crossmodal_pred.py | 35 +++++++--------------------- 1 file changed, 8 insertions(+), 27 deletions(-) diff --git a/flexynesis/models/crossmodal_pred.py b/flexynesis/models/crossmodal_pred.py index 0554b8f..dc86859 100644 --- a/flexynesis/models/crossmodal_pred.py +++ b/flexynesis/models/crossmodal_pred.py @@ -30,7 +30,7 @@ class CrossModalPred(pl.LightningModule): def __init__(self, config, dataset, target_variables = None, batch_variables = None, surv_event_var = None, surv_time_var = None, input_layers = None, output_layers = None, - val_size = 0.2, use_loss_weighting = True, + use_loss_weighting = True, device_type = None): super(CrossModalPred, self).__init__() self.config = config @@ -49,10 +49,6 @@ def __init__(self, config, dataset, target_variables = None, batch_variables = self.input_layers = input_layers if input_layers else list(dataset.dat.keys()) self.output_layers = output_layers if output_layers else list(dataset.dat.keys()) - self.val_size = val_size - - self.prepare_data_loaders(dataset) - self.feature_importances = {} self.device_type = device_type @@ -209,7 +205,7 @@ def compute_total_loss(self, losses): return total_loss - def training_step(self, train_batch, batch_idx): + def training_step(self, train_batch, batch_idx, log = True): dat, y_dict = train_batch # get input omics modalities and encode them; decode them to output layers @@ -239,10 +235,11 @@ def training_step(self, train_batch, batch_idx): total_loss = self.compute_total_loss(losses) # add total loss for logging losses['train_loss'] = total_loss - self.log_dict(losses, on_step=False, on_epoch=True, prog_bar=True) + if log: + self.log_dict(losses, on_step=False, on_epoch=True, prog_bar=True) return total_loss - def validation_step(self, val_batch, batch_idx): + def validation_step(self, val_batch, batch_idx, log = True): dat, y_dict = val_batch # get input omics modalities and encode them @@ -270,26 +267,10 @@ def validation_step(self, val_batch, batch_idx): total_loss = sum(losses.values()) losses['val_loss'] = total_loss - self.log_dict(losses, on_step=False, on_epoch=True, prog_bar=True) + if log: + self.log_dict(losses, on_step=False, on_epoch=True, prog_bar=True) return total_loss - - def prepare_data_loaders(self, dataset): - - lt = int(len(dataset)*(1-self.val_size)) - lv = len(dataset)-lt - dat_train, dat_val = 
random_split(dataset, [lt, lv], - generator=torch.Generator().manual_seed(42)) - self.train_loader = DataLoader(dat_train, batch_size=int(self.config['batch_size']), - num_workers=0, pin_memory=True, shuffle=True, drop_last=True) - self.val_loader = DataLoader(dat_val, batch_size=int(self.config['batch_size']), - num_workers=0, pin_memory=True, shuffle=False) - - def train_dataloader(self): - return self.train_loader - - def val_dataloader(self): - return self.val_loader - + def transform(self, dataset): """ Transform the input dataset to latent representation.
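
Taken together, the series changes the HyperparameterTuning contract: the val_size argument is gone, cv_splits sets the number of folds, every candidate configuration is scored by its mean validation loss across those folds, the per-fold stopping epochs are averaged to set the epoch budget of the final fit, and perform_tuning() returns a model refit on the full dataset with the best parameters alongside the best-parameter dict. A rough usage sketch under those assumptions follows; my_dataset, the target-variable name, and the config_name value are illustrative placeholders, not taken from the repository.

from flexynesis.main import HyperparameterTuning
from flexynesis.models.direct_pred import DirectPred

# `my_dataset` and the target variable name below are hypothetical examples.
tuner = HyperparameterTuning(
    dataset=my_dataset,
    model_class=DirectPred,
    config_name='DirectPred',     # assumed to match a key in the default config
    target_variables=['outcome'],
    batch_variables=None,
    n_iter=10,
    cv_splits=5,                  # replaces the former val_size argument
    use_loss_weighting=True,
    early_stop_patience=10,       # applied per fold, skipped for the final full-data fit
    device_type='cpu',
)

# Each suggested configuration is trained and validated on cv_splits folds;
# the best one is then refit on the entire dataset and returned.
best_model, best_params = tuner.perform_tuning()
print(best_params)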