
Merge pull request #74 from BIMSBbioinfo/crossvalidation
Do model training using cross-validation
borauyar authored May 1, 2024
2 parents 6b70d40 + 5576b63 commit 13a42e5
Showing 5 changed files with 107 additions and 142 deletions.
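In short, hyperparameter evaluation moves from a single fixed validation split (`val_size`) to k-fold cross-validation (`cv_splits`), and data-loader construction moves out of the model classes into `HyperparameterTuning`. A minimal sketch of the fold loop the tuner adopts below (illustrative only; `build_model` and `fit_and_score` are hypothetical stand-ins, not the flexynesis API):

```python
# Minimal sketch of the k-fold evaluation pattern adopted in this PR.
# `build_model` and `fit_and_score` are hypothetical stand-ins.
import numpy as np
from sklearn.model_selection import KFold
from torch.utils.data import DataLoader, Subset

def cross_validate(dataset, build_model, n_splits=5, batch_size=32):
    kf = KFold(n_splits=n_splits, shuffle=True)
    fold_losses = []
    for train_idx, val_idx in kf.split(dataset):
        train_loader = DataLoader(Subset(dataset, train_idx),
                                  batch_size=batch_size, shuffle=True)
        val_loader = DataLoader(Subset(dataset, val_idx),
                                batch_size=batch_size)
        model = build_model()                      # fresh model per fold
        fold_losses.append(model.fit_and_score(train_loader, val_loader))
    return float(np.mean(fold_losses))             # score handed to the tuner
```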
147 changes: 95 additions & 52 deletions flexynesis/main.py
@@ -37,7 +37,7 @@ class HyperparameterTuning:
config_name: Name of the configuration for tuning parameters.
n_iter: Number of iterations for the tuning process.
plot_losses: Boolean flag to plot losses during training.
val_size: Validation set size as a fraction of the dataset.
cv_splits: Number of cross-validation folds.
use_loss_weighting: Flag to use loss weighting during training.
early_stop_patience: Number of epochs to wait for improvement before stopping.
device_type: Str ('cpu' or 'gpu')
@@ -50,10 +50,11 @@ class HyperparameterTuning:
def __init__(self, dataset, model_class, config_name, target_variables,
batch_variables = None, surv_event_var = None, surv_time_var = None,
n_iter = 10, config_path = None, plot_losses = False,
val_size = 0.2, use_loss_weighting = True, early_stop_patience = -1,
cv_splits = 5, use_loss_weighting = True, early_stop_patience = -1,
device_type = None, gnn_conv_type = None,
input_layers = None, output_layers = None):
self.dataset = dataset
self.dataset = dataset # dataset for model initialization
self.loader_dataset = dataset # dataset for defining data loaders (may be model-specific)
self.model_class = model_class
self.target_variables = target_variables
self.device_type = device_type
@@ -65,7 +66,7 @@ def __init__(self, dataset, model_class, config_name, target_variables,
self.config_name = config_name
self.n_iter = n_iter
self.plot_losses = plot_losses # Whether to show live loss plots (useful in interactive mode)
self.val_size = val_size
self.n_splits = cv_splits
self.progress_bar = RichProgressBar(
theme = RichProgressBarTheme(
progress_bar = 'green1',
@@ -77,6 +78,9 @@ def __init__(self, dataset, model_class, config_name, target_variables,
self.input_layers = input_layers
self.output_layers = output_layers

if self.model_class.__name__ == 'MultiTripletNetwork':
self.loader_dataset = TripletMultiOmicDataset(self.dataset, self.target_variables[0])

# If config_path is provided, use it
if config_path:
external_config = self.load_and_convert_config(config_path)
@@ -93,94 +97,133 @@ def __init__(self, dataset, model_class, config_name, target_variables,
raise ValueError(f"'{self.config_name}' not found in the default config.")

def get_batch_space(self, min_size = 16, max_size = 256):
m = int(np.log2(len(self.dataset) * (1 - self.val_size)))
m = int(np.log2(len(self.dataset) * 0.8)) # 0.8 ~ training fraction of one fold with 5 CV splits
st = int(np.log2(min_size))
end = int(np.log2(max_size))
if m < end:
end = m
s = Categorical([np.power(2, x) for x in range(st, end+1)], name = 'batch_size')
return s
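The batch-size search space is the set of powers of two between min_size and max_size, capped so a batch never exceeds roughly one training fold. A standalone worked example of the same arithmetic (mirroring get_batch_space rather than calling it):

```python
# Worked example of the batch-size space logic above.
import numpy as np

def batch_space(n_samples, min_size=16, max_size=256):
    m = int(np.log2(n_samples * 0.8))          # cap by ~one training fold
    st, end = int(np.log2(min_size)), int(np.log2(max_size))
    end = min(end, m)
    return [2 ** x for x in range(st, end + 1)]

print(batch_space(500))  # [16, 32, 64, 128, 256]
print(batch_space(100))  # [16, 32, 64]  -- small datasets get smaller caps
```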

def objective(self, params, current_step, total_steps):

# args common to all model classes
model_args = {"config": params, "dataset": self.dataset, "target_variables": self.target_variables,
"batch_variables": self.batch_variables, "surv_event_var": self.surv_event_var,
"surv_time_var": self.surv_time_var, "val_size": self.val_size,
"use_loss_weighting": self.use_loss_weighting, "device_type": self.device_type}
if self.model_class.__name__ == 'DirectPredGCNN':
model_args["gnn_conv_type"] = self.gnn_conv_type
if self.model_class.__name__ == 'CrossModalPred':
model_args["input_layers"] = self.input_layers
model_args["output_layers"] = self.output_layers

model = self.model_class(**model_args)
print(params)

def setup_trainer(self, params, current_step, total_steps, full_train = False):
# Configure callbacks and trainer for the current fold
mycallbacks = [self.progress_bar]
# for interactive usage, add live loss plots
if self.plot_losses:
mycallbacks = [LiveLossPlot(hyperparams=params, current_step=current_step, total_steps=total_steps)]

if self.early_stop_patience > 0:
mycallbacks.append(self.init_early_stopping())

trainer = pl.Trainer(max_epochs=int(params['epochs']), log_every_n_steps=5,
callbacks = mycallbacks, default_root_dir="./", logger=False,
enable_checkpointing=False,
devices=1, accelerator=self.device_type)

# Create a new Trainer instance for validation, ensuring single-device processing
validation_trainer = pl.Trainer(
logger=False,
mycallbacks.append(LiveLossPlot(hyperparams=params, current_step=current_step, total_steps=total_steps))
# when training on the full dataset there are no validation splits,
# so early stopping is skipped
early_stop_callback = None
if self.early_stop_patience > 0 and not full_train:
early_stop_callback = self.init_early_stopping()
mycallbacks.append(early_stop_callback)

trainer = pl.Trainer(
max_epochs=int(params['epochs']),
log_every_n_steps=5,
callbacks=mycallbacks,
default_root_dir="./",
logger=False,
enable_checkpointing=False,
devices=1, # make sure to use a single device for validation
devices=1,
accelerator=self.device_type
)
return trainer, early_stop_callback

def objective(self, params, current_step, total_steps, full_train = False):
# Unpack or construct specific model arguments
model_args = {
"config": params,
"dataset": self.dataset,
"target_variables": self.target_variables,
"batch_variables": self.batch_variables,
"surv_event_var": self.surv_event_var,
"surv_time_var": self.surv_time_var,
"use_loss_weighting": self.use_loss_weighting,
"device_type": self.device_type,
}

try:
# Train the model
trainer.fit(model)
# Validate the model
val_loss = validation_trainer.validate(model)[0]['val_loss']
except ValueError as e:
print(str(e))
val_loss = float('inf') # or some other value indicating failure
return val_loss, model
if self.model_class.__name__ == 'DirectPredGCNN':
model_args['gnn_conv_type'] = self.gnn_conv_type
if self.model_class.__name__ == 'CrossModalPred':
model_args['input_layers'] = self.input_layers
model_args['output_layers'] = self.output_layers

if full_train:
# Train on the full dataset
full_loader = DataLoader(self.loader_dataset, batch_size=int(params['batch_size']), shuffle=True)
model = self.model_class(**model_args)
trainer, _ = self.setup_trainer(params, current_step, total_steps, full_train = True)
trainer.fit(model, train_dataloaders=full_loader)
return model # Return the trained model

else:
# Perform k-fold cross-validation
validation_losses = []
kf = KFold(n_splits=self.n_splits, shuffle=True)
i = 1
epochs = [] # number of epochs per fold
for train_index, val_index in kf.split(self.loader_dataset):
print(f"[INFO] training cross-validation fold {i}")
train_subset = torch.utils.data.Subset(self.loader_dataset, train_index)
val_subset = torch.utils.data.Subset(self.loader_dataset, val_index)
train_loader = DataLoader(train_subset, batch_size=int(params['batch_size']), shuffle=True)
val_loader = DataLoader(val_subset, batch_size=int(params['batch_size']))

model = self.model_class(**model_args)
trainer, early_stop_callback = self.setup_trainer(params, current_step, total_steps)
trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)
if early_stop_callback and early_stop_callback.stopped_epoch: # guard: callback is None when early stopping is disabled
epochs.append(early_stop_callback.stopped_epoch)
else:
epochs.append(int(params['epochs']))
validation_result = trainer.validate(model, dataloaders=val_loader)
val_loss = validation_result[0]['val_loss']
validation_losses.append(val_loss)
i += 1

# Calculate average validation loss across all folds
avg_val_loss = np.mean(validation_losses)
avg_epochs = int(np.mean(epochs))
return avg_val_loss, avg_epochs
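The averaged epoch count matters because the final model is retrained on the full dataset, where early stopping has nothing to monitor; the mean stopped epoch across folds becomes a fixed epoch budget. Illustrative numbers only:

```python
# Illustrative numbers: deriving the epoch budget for the full-data fit.
import numpy as np

max_epochs = 50
stopped = [34, 41, 0, 38, 36]                 # 0 => fold ran to completion
epochs = [e if e else max_epochs for e in stopped]
budget = int(np.mean(epochs))                 # (34+41+50+38+36)/5 -> 39
print(f"final model trains for {budget} epochs on the full dataset")
```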

def perform_tuning(self):
opt = Optimizer(dimensions=self.space, n_initial_points=10, acq_func="gp_hedge", acq_optimizer="auto")

best_loss = np.inf
best_params = None
best_model = None

best_epochs = 0
with tqdm(total=self.n_iter, desc='Tuning Progress') as pbar:
for i in range(self.n_iter):
np.int = int # shim: scikit-optimize still references np.int, which NumPy 1.24 removed
suggested_params_list = opt.ask()
suggested_params_dict = {param.name: value for param, value in zip(self.space, suggested_params_list)}
loss, model = self.objective(suggested_params_dict, current_step=i+1, total_steps=self.n_iter)
loss, avg_epochs = self.objective(suggested_params_dict, current_step=i+1, total_steps=self.n_iter)
print(f"[INFO] average 5-fold cross-validation loss {loss} for params: {suggested_params_dict}")
opt.tell(suggested_params_list, loss)

if loss < best_loss:
best_loss = loss
best_params = suggested_params_list
best_model = model
best_epochs = avg_epochs

# Print result of each iteration
pbar.set_postfix({'Iteration': i+1, 'Best Loss': best_loss})
pbar.update(1)
# convert best params to dict
best_params_dict = {param.name: value for param, value in zip(self.space, best_params)}
best_params_dict['epochs'] = best_epochs # epoch budget recorded for the best iteration, not the last one
# build a final model based on best params
print(f"[INFO] Building a final model using best params: {best_params_dict}")
best_model = self.objective(best_params_dict, current_step=0, total_steps=1, full_train=True)
return best_model, best_params_dict
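A hedged end-to-end usage sketch based on the signature above; the import paths, `my_dataset`, and the config name are assumptions, not verified against the repo:

```python
# Hedged usage sketch of the tuner with the new cv_splits argument.
from flexynesis.main import HyperparameterTuning          # assumed path
from flexynesis.models.direct_pred import DirectPred      # assumed path

tuner = HyperparameterTuning(
    dataset=my_dataset,             # a MultiOmicDataset-like object (assumed)
    model_class=DirectPred,
    config_name='DirectPred',       # must exist in the default config
    target_variables=['y'],
    n_iter=10,
    cv_splits=5,                    # new: folds instead of val_size
    early_stop_patience=10,
    device_type='cpu',
)
best_model, best_params = tuner.perform_tuning()
```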

def init_early_stopping(self):
"""Initialize the early stopping callback."""
return EarlyStopping(
monitor='val_loss',
patience=self.early_stop_patience,
verbose=True,
verbose=False,
mode='min'
)

@@ -296,9 +339,9 @@ def run_experiments(self):
)
trainer = pl.Trainer(max_epochs=self.max_epoch, devices=1, accelerator='auto', logger=False, enable_checkpointing=False,
enable_progress_bar = False, enable_model_summary=False, callbacks=[early_stopping])
trainer.fit(self)
trainer.fit(self, train_dataloaders=self.train_dataloader(), val_dataloaders=self.val_dataloader())
stopped_epoch = early_stopping.stopped_epoch
val_loss = trainer.validate(self.model, verbose = False)
val_loss = trainer.validate(self, dataloaders = self.val_dataloader(), verbose = False)
fold_losses.append(val_loss[0]['val_loss']) # Adjust based on your validation output format
epochs.append(stopped_epoch)
#print(f"[INFO] Finetuning ... training fold: {fold}, learning rate: {lr}, val_loss: {val_loss}, freeze {config}")
35 changes: 8 additions & 27 deletions flexynesis/models/crossmodal_pred.py
@@ -30,7 +30,7 @@ class CrossModalPred(pl.LightningModule):
def __init__(self, config, dataset, target_variables = None, batch_variables = None,
surv_event_var = None, surv_time_var = None,
input_layers = None, output_layers = None,
val_size = 0.2, use_loss_weighting = True,
use_loss_weighting = True,
device_type = None):
super(CrossModalPred, self).__init__()
self.config = config
@@ -49,10 +49,6 @@ def __init__(self, config, dataset, target_variables = None, batch_variables =
self.input_layers = input_layers if input_layers else list(dataset.dat.keys())
self.output_layers = output_layers if output_layers else list(dataset.dat.keys())

self.val_size = val_size

self.prepare_data_loaders(dataset)

self.feature_importances = {}

self.device_type = device_type
@@ -209,7 +205,7 @@ def compute_total_loss(self, losses):
return total_loss


def training_step(self, train_batch, batch_idx):
def training_step(self, train_batch, batch_idx, log = True):
dat, y_dict = train_batch

# get input omics modalities and encode them; decode them to output layers
@@ -239,10 +235,11 @@ def training_step(self, train_batch, batch_idx):
total_loss = self.compute_total_loss(losses)
# add total loss for logging
losses['train_loss'] = total_loss
self.log_dict(losses, on_step=False, on_epoch=True, prog_bar=True)
if log:
self.log_dict(losses, on_step=False, on_epoch=True, prog_bar=True)
return total_loss

def validation_step(self, val_batch, batch_idx):
def validation_step(self, val_batch, batch_idx, log = True):
dat, y_dict = val_batch

# get input omics modalities and encode them
@@ -270,26 +267,10 @@ def validation_step(self, val_batch, batch_idx):

total_loss = sum(losses.values())
losses['val_loss'] = total_loss
self.log_dict(losses, on_step=False, on_epoch=True, prog_bar=True)
if log:
self.log_dict(losses, on_step=False, on_epoch=True, prog_bar=True)
return total_loss
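The new `log` parameter lets these steps be invoked outside a Trainer (e.g. for a quick per-fold sanity check), where `self.log_dict` would otherwise complain about a missing trainer reference. Illustrative fragment only; `model` and `val_loader` are assumed to exist as in the tuner above:

```python
# Illustrative: log=False permits calling the step without a Trainer.
import torch

batch = next(iter(val_loader))                # a (dat, y_dict) pair
with torch.no_grad():
    loss = model.validation_step(batch, batch_idx=0, log=False)
print(float(loss))
```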

def prepare_data_loaders(self, dataset):

lt = int(len(dataset)*(1-self.val_size))
lv = len(dataset)-lt
dat_train, dat_val = random_split(dataset, [lt, lv],
generator=torch.Generator().manual_seed(42))
self.train_loader = DataLoader(dat_train, batch_size=int(self.config['batch_size']),
num_workers=0, pin_memory=True, shuffle=True, drop_last=True)
self.val_loader = DataLoader(dat_val, batch_size=int(self.config['batch_size']),
num_workers=0, pin_memory=True, shuffle=False)

def train_dataloader(self):
return self.train_loader

def val_dataloader(self):
return self.val_loader


def transform(self, dataset):
"""
Transform the input dataset to latent representation.
24 changes: 1 addition & 23 deletions flexynesis/models/direct_pred.py
@@ -16,7 +16,7 @@

class DirectPred(pl.LightningModule):
def __init__(self, config, dataset, target_variables, batch_variables = None,
surv_event_var = None, surv_time_var = None, val_size = 0.2, use_loss_weighting = True,
surv_event_var = None, surv_time_var = None, use_loss_weighting = True,
device_type = None):
super(DirectPred, self).__init__()
self.config = config
@@ -29,14 +29,10 @@ def __init__(self, config, dataset, target_variables, batch_variables = None,
self.target_variables = self.target_variables + [self.surv_event_var]
self.batch_variables = batch_variables
self.variables = self.target_variables + batch_variables if batch_variables else self.target_variables
self.val_size = val_size
self.feature_importances = {}
self.use_loss_weighting = use_loss_weighting
self.device_type = device_type

# define data loaders
self.prepare_data_loaders(dataset)

if self.use_loss_weighting:
# Initialize log variance parameters for uncertainty weighting
self.log_vars = nn.ParameterDict()
@@ -196,24 +192,6 @@ def validation_step(self, val_batch, batch_idx, log = True):
self.log_dict(losses, on_step=False, on_epoch=True, prog_bar=True)
return total_loss

def prepare_data_loaders(self, dataset):
# Split the dataset
train_size = int(len(dataset) * (1 - self.val_size))
val_size = len(dataset) - train_size
dat_train, dat_val = random_split(dataset, [train_size, val_size], generator=torch.Generator().manual_seed(42))

# Create data loaders
self.train_loader = DataLoader(dat_train, batch_size=int(self.config['batch_size']),
num_workers=0, pin_memory=True, shuffle=True, drop_last=True)
self.val_loader = DataLoader(dat_val, batch_size=int(self.config['batch_size']),
num_workers=0, pin_memory=True, shuffle=False)

def train_dataloader(self):
return self.train_loader

def val_dataloader(self):
return self.val_loader

def predict(self, dataset):
"""
Evaluate the DirectPred model on a given dataset.
18 changes: 1 addition & 17 deletions flexynesis/models/supervised_vae.py
@@ -43,7 +43,7 @@ class supervised_vae(pl.LightningModule):
"""
def __init__(self, config, dataset, target_variables, batch_variables = None,
surv_event_var = None, surv_time_var = None, val_size = 0.2, use_loss_weighting = True,
surv_event_var = None, surv_time_var = None, use_loss_weighting = True,
device_type = None):
super(supervised_vae, self).__init__()
self.config = config
@@ -57,9 +57,6 @@ def __init__(self, config, dataset, target_variables, batch_variables = None,
self.target_variables = self.target_variables + [self.surv_event_var]
self.batch_variables = batch_variables
self.variables = self.target_variables + batch_variables if batch_variables else self.target_variables
self.val_size = val_size

self.dat_train, self.dat_val = self.prepare_data()
self.feature_importances = {}

# sometimes the model may have exploding/vanishing gradients leading to NaN values
@@ -279,19 +276,6 @@ def validation_step(self, val_batch, batch_idx, log = True):
self.log_dict(losses, on_step=False, on_epoch=True, prog_bar=True)
return total_loss

def prepare_data(self):
lt = int(len(self.dataset)*(1-self.val_size))
lv = len(self.dataset)-lt
dat_train, dat_val = random_split(self.dataset, [lt, lv],
generator=torch.Generator().manual_seed(42))
return dat_train, dat_val

def train_dataloader(self):
return DataLoader(self.dat_train, batch_size=int(self.config['batch_size']), num_workers=0, pin_memory=True, shuffle=True, drop_last=True)

def val_dataloader(self):
return DataLoader(self.dat_val, batch_size=int(self.config['batch_size']), num_workers=0, pin_memory=True, shuffle=False)

def transform(self, dataset):
"""
Transform the input dataset to latent representation.
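Across crossmodal_pred.py, direct_pred.py, and supervised_vae.py the refactor is the same: `val_size`, `prepare_data_loaders`/`prepare_data`, `train_dataloader`, and `val_dataloader` disappear from the models, and loaders are built by the caller and handed to Lightning explicitly, as in the tuner above. Schematically (`train_subset`, `val_subset`, `trainer`, and `model` assumed):

```python
# Schematic of the new data flow: the caller owns the loaders.
from torch.utils.data import DataLoader

train_loader = DataLoader(train_subset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_subset, batch_size=32)

trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)
trainer.validate(model, dataloaders=val_loader)
```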