minor change
shreyasgrampurohit committed Oct 3, 2023
1 parent 23e3aab commit 7f20339
Showing 1 changed file with 13 additions and 6 deletions.
19 changes: 13 additions & 6 deletions proxy-pipeline/train_hist_gradient_boosting_regressor.py
@@ -42,9 +42,9 @@
 # Monotonic constraints
 # Interaction constraints
 flags.DEFINE_bool('warm_start', False, 'When set to True, reuse the solution of the previous call to fit as initialization, otherwise, just erase the previous solution')
-# Early stopping
-# Scoring
-# Validation fraction
+flags.DEFINE_bool('early_stopping', 'auto', 'Whether to use early stopping to terminate training when validation score is not improving')
+flags.DEFINE_string('scoring', 'loss', 'Scoring parameter to use for early stopping')
+flags.DEFINE_float('validation_fraction', 0.1, 'The proportion of training data to set aside as validation set for early stopping')
 flags.DEFINE_integer('n_iter_no_change', 10, 'Maximum number of iterations with no improvement to wait before early stopping')
 flags.DEFINE_float('tol', 1e-7, 'The absolute tolerance to use when comparing scores during early stopping')
 flags.DEFINE_integer('verbose', 0, 'The verbosity level')
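
Note: if these are absl flags, DEFINE_bool will reject the string default 'auto' at definition time (absl's boolean parser only accepts true/false-style values), while scikit-learn's early_stopping parameter accepts either 'auto' or a bool. A minimal sketch of one way to expose all three values, assuming the same flag name as the diff; the normalize_early_stopping helper is hypothetical and not part of the script:

from absl import flags

FLAGS = flags.FLAGS

# Hypothetical alternative to DEFINE_bool('early_stopping', 'auto', ...):
# accept 'auto', 'true', or 'false' and convert before use.
flags.DEFINE_enum('early_stopping', 'auto', ['auto', 'true', 'false'],
                  'Whether to use early stopping ("auto" lets scikit-learn decide)')

def normalize_early_stopping(value):
    """Map the string flag onto the 'auto'-or-bool value scikit-learn expects."""
    if value == 'auto':
        return 'auto'
    return value == 'true'

# e.g. HistGradientBoostingRegressor(early_stopping=normalize_early_stopping(FLAGS.early_stopping), ...)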
@@ -192,6 +192,10 @@ def main(_):
     if not os.path.exists(exp_path):
         os.makedirs(exp_path)
 
+    # Handle validation fraction flag
+    if FLAGS.validation_fraction >= 1:
+        FLAGS.validation_fraction = int(FLAGS.validation_fraction)
+
     # Load the data
     if FLAGS.custom_dataset:
         actions_path = os.path.join(FLAGS.data_path, 'actions_feasible.csv')
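
The conversion added above matches scikit-learn's convention for validation_fraction: a float in (0, 1) is read as a proportion of the training data, while an integer is read as an absolute number of validation samples, so values of 1 or more only make sense as integers. A standalone illustrative sketch (not part of the script, value chosen for illustration):

# Illustrative only: how the flag value is interpreted after the conversion.
validation_fraction = 200.0  # e.g. supplied on the command line via a float flag

if validation_fraction >= 1:
    # scikit-learn treats an int as an absolute number of validation samples,
    # and a float below 1 as a fraction of the training set.
    validation_fraction = int(validation_fraction)

print(validation_fraction)  # 200 -> 200 held-out samples rather than a ratio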
@@ -224,9 +228,12 @@
 
     # Define the model
     regressor = HistGradientBoostingRegressor(loss=FLAGS.loss, learning_rate=FLAGS.learning_rate, max_iter=FLAGS.max_iter,
-                                              max_leaf_nodes=FLAGS.max_leaf_nodes, max_depth=FLAGS.max_depth, min_samples_leaf=FLAGS.min_samples_leaf,
-                                              l2_regularization=FLAGS.l2_regularization, max_bins=FLAGS.max_bins, warm_start=FLAGS.warm_start,
-                                              n_iter_no_change=FLAGS.n_iter_no_change, tol=FLAGS.tol, verbose=FLAGS.verbose, random_state=FLAGS.random_state)
+                                              max_leaf_nodes=FLAGS.max_leaf_nodes, max_depth=FLAGS.max_depth,
+                                              min_samples_leaf=FLAGS.min_samples_leaf, l2_regularization=FLAGS.l2_regularization,
+                                              max_bins=FLAGS.max_bins, warm_start=FLAGS.warm_start, early_stopping=FLAGS.early_stopping,
+                                              scoring=FLAGS.scoring, validation_fraction=FLAGS.validation_fraction,
+                                              n_iter_no_change=FLAGS.n_iter_no_change, tol=FLAGS.tol, verbose=FLAGS.verbose,
+                                              random_state=FLAGS.random_state)
 
     # Train the model
     regressor.fit(X_train, y_train[:, 0])
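
For reference, a minimal standalone sketch of the early-stopping behaviour these new keyword arguments enable; the synthetic data and parameter values here are chosen for illustration and are not taken from the script:

import numpy as np
from sklearn.ensemble import HistGradientBoostingRegressor

rng = np.random.RandomState(0)
X = rng.normal(size=(2000, 5))
y = X[:, 0] - 2.0 * X[:, 1] + rng.normal(scale=0.1, size=2000)

regressor = HistGradientBoostingRegressor(
    max_iter=500,
    early_stopping=True,       # force early stopping on (default is 'auto')
    scoring='loss',            # use the model's loss as the early-stopping score
    validation_fraction=0.1,   # hold out 10% of the data for validation
    n_iter_no_change=10,
    tol=1e-7,
    random_state=0,
)
regressor.fit(X, y)

# With early stopping, training usually halts well before max_iter.
print(regressor.n_iter_)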
