diff --git a/proxy-pipeline/train_hist_gradient_boosting_regressor.py b/proxy-pipeline/train_hist_gradient_boosting_regressor.py index 2d5add73..a650a5f5 100644 --- a/proxy-pipeline/train_hist_gradient_boosting_regressor.py +++ b/proxy-pipeline/train_hist_gradient_boosting_regressor.py @@ -42,9 +42,9 @@ # Monotonic constraints # Interaction constraints flags.DEFINE_bool('warm_start', False, 'When set to True, reuse the solution of the previous call to fit as initialization, otherwise, just erase the previous solution') -# Early stopping -# Scoring -# Validation fraction +flags.DEFINE_bool('early_stopping', 'auto', 'Whether to use early stopping to terminate training when validation score is not improving') +flags.DEFINE_string('scoring', 'loss', 'Scoring parameter to use for early stopping') +flags.DEFINE_float('validation_fraction', 0.1, 'The proportion of training data to set aside as validation set for early stopping') flags.DEFINE_integer('n_iter_no_change', 10, 'Maximum number of iterations with no improvement to wait before early stopping') flags.DEFINE_float('tol', 1e-7, 'The absolute tolerance to use when comparing scores during early stopping') flags.DEFINE_integer('verbose', 0, 'The verbosity level') @@ -192,6 +192,10 @@ def main(_): if not os.path.exists(exp_path): os.makedirs(exp_path) + # Handle validation fraction flag + if FLAGS.validation_fraction >= 1: + FLAGS.validation_fraction = int(FLAGS.validation_fraction) + # Load the data if FLAGS.custom_dataset: actions_path = os.path.join(FLAGS.data_path, 'actions_feasible.csv') @@ -224,9 +228,12 @@ def main(_): # Define the model regressor = HistGradientBoostingRegressor(loss=FLAGS.loss, learning_rate=FLAGS.learning_rate, max_iter=FLAGS.max_iter, - max_leaf_nodes=FLAGS.max_leaf_nodes, max_depth=FLAGS.max_depth, min_samples_leaf=FLAGS.min_samples_leaf, - l2_regularization=FLAGS.l2_regularization, max_bins=FLAGS.max_bins, warm_start=FLAGS.warm_start, - n_iter_no_change=FLAGS.n_iter_no_change, tol=FLAGS.tol, verbose=FLAGS.verbose, random_state=FLAGS.random_state) + max_leaf_nodes=FLAGS.max_leaf_nodes, max_depth=FLAGS.max_depth, + min_samples_leaf=FLAGS.min_samples_leaf, l2_regularization=FLAGS.l2_regularization, + max_bins=FLAGS.max_bins, warm_start=FLAGS.warm_start, early_stopping=FLAGS.early_stopping, + scoring=FLAGS.scoring, validation_fraction=FLAGS.validation_fraction, + n_iter_no_change=FLAGS.n_iter_no_change, tol=FLAGS.tol, verbose=FLAGS.verbose, + random_state=FLAGS.random_state) # Train the model regressor.fit(X_train, y_train[:, 0])