diff --git a/flexynesis/__main__.py b/flexynesis/__main__.py
index 1a21ed5..3a6476c 100644
--- a/flexynesis/__main__.py
+++ b/flexynesis/__main__.py
@@ -53,6 +53,7 @@ def main():
     parser.add_argument("--prefix", help="Job prefix to use for output files", type=str, default = 'job')
     parser.add_argument("--log_transform", help="whether to apply log-transformation to input data matrices", type=str, choices=['True', 'False'], default = 'False')
     parser.add_argument("--early_stop_patience", help="How many epochs to wait when no improvements in validation loss is observed (default: 10; set to -1 to disable early stopping)", type=int, default = 10)
+    parser.add_argument("--hpo_patience", help="How many hyperparameter optimisation iterations to wait when no improvement is observed (default: 10; set to 0 to disable early stopping)", type=int, default = 10)
     parser.add_argument("--use_cv", action="store_true", help="(Optional) If set, the a 5-fold cross-validation training will be done. Otherwise, a single trainign on 80% of the dataset is done.")
     parser.add_argument("--use_loss_weighting", help="whether to apply loss-balancing using uncertainty weights method", type=str, choices=['True', 'False'], default = 'True')
@@ -218,7 +219,7 @@ class AvailableModels(NamedTuple):
                           output_layers = output_layers)
 
     # do a hyperparameter search training multiple models and get the best_configuration
-    model, best_params = tuner.perform_tuning()
+    model, best_params = tuner.perform_tuning(hpo_patience = args.hpo_patience)
 
     # if fine-tuning is enabled; fine tune the model on a portion of test samples
     if args.finetuning_samples > 0:
diff --git a/flexynesis/main.py b/flexynesis/main.py
index a1c8dbb..12e8c07 100644
--- a/flexynesis/main.py
+++ b/flexynesis/main.py
@@ -208,17 +208,20 @@ def objective(self, params, current_step, total_steps, full_train = False):
         avg_epochs = int(np.mean(epochs))
         return avg_val_loss, avg_epochs, model
 
-    def perform_tuning(self):
+    def perform_tuning(self, hpo_patience = 0):
         opt = Optimizer(dimensions=self.space, n_initial_points=10, acq_func="gp_hedge", acq_optimizer="auto")
 
         best_loss = np.inf
         best_params = None
         best_epochs = 0
         best_model = None
-
+
+        # keep track of the streak of HPO iterations without improvement
+        no_improvement_count = 0
+
         with tqdm(total=self.n_iter, desc='Tuning Progress') as pbar:
             for i in range(self.n_iter):
-                np.int = int
+                np.int = int  # Ensure int type is correctly handled
                 suggested_params_list = opt.ask()
                 suggested_params_dict = {param.name: value for param, value in zip(self.space, suggested_params_list)}
                 loss, avg_epochs, model = self.objective(suggested_params_dict, current_step=i+1, total_steps=self.n_iter)
@@ -229,21 +232,31 @@ def perform_tuning(self):
                 if loss < best_loss:
                     best_loss = loss
                     best_params = suggested_params_list
-                    best_epochs = avg_epochs
+                    best_epochs = avg_epochs
                     best_model = model
+                    no_improvement_count = 0  # Reset the no improvement counter
+                else:
+                    no_improvement_count += 1  # Increment the no improvement counter
                 # Print result of each iteration
                 pbar.set_postfix({'Iteration': i+1, 'Best Loss': best_loss})
                 pbar.update(1)
-        # convert best params to dict
+
+                # Early stopping condition
+                if hpo_patience > 0 and no_improvement_count >= hpo_patience:
+                    print(f"No improvement in best loss for {hpo_patience} iterations, stopping hyperparameter optimisation early.")
+                    break  # Break out of the loop
+
+        # Convert best parameters from list to dictionary and include epochs
         best_params_dict = {param.name: value for param, value in zip(self.space, best_params)}
-        best_params_dict['epochs'] = avg_epochs
+        best_params_dict['epochs'] = best_epochs
+
         if self.use_cv:
-            # build a final model based on best params if a cross-validation
+            # Build a final model based on best parameters if using cross-validation
             print(f"[INFO] Building a final model using best params: {best_params_dict}")
-            best_model = self.objective(best_params_dict, current_step=0, total_steps=1, full_train=True)
-        return best_model, best_params_dict
-
+            best_model = self.objective(best_params_dict, current_step=0, total_steps=1, full_train=True)
+
+        return best_model, best_params_dict
 
     def init_early_stopping(self):
         """Initialize the early stopping callback."""
         return EarlyStopping(
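
Note on the added logic: the change boils down to a patience counter over the best validation loss seen so far, checked once per HPO iteration. The standalone sketch below illustrates the same pattern outside flexynesis; run_search and sample_loss are hypothetical names used only for illustration, not part of the flexynesis API, and a non-positive patience disables the check, matching --hpo_patience 0.

# Illustrative sketch of the patience-based early stopping added to perform_tuning().
# run_search and sample_loss are hypothetical names, not flexynesis code.
import numpy as np

def run_search(n_iter, patience, sample_loss):
    """Return the best loss observed; stop after `patience` consecutive non-improving iterations."""
    best_loss = np.inf
    no_improvement_count = 0
    for i in range(n_iter):
        loss = sample_loss(i)
        if loss < best_loss:
            best_loss = loss
            no_improvement_count = 0      # reset the streak on any improvement
        else:
            no_improvement_count += 1     # extend the streak of non-improving iterations
        # patience <= 0 disables early stopping, mirroring --hpo_patience 0
        if patience > 0 and no_improvement_count >= patience:
            print(f"No improvement for {patience} iterations, stopping at iteration {i+1}.")
            break
    return best_loss

if __name__ == "__main__":
    rng = np.random.default_rng(0)
    print(run_search(n_iter=100, patience=10, sample_loss=lambda i: rng.random()))

With the argparse default of 10, the search stops after ten consecutive iterations without a new best loss; this parallels --early_stop_patience, which plays the same role at the epoch level and uses -1 rather than 0 as its "disabled" value.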