Added SKLearn-like random forest Python API. (dmlc#4148)
* Added SKLearn-like random forest Python API.

- added XGBRFClassifier and XGBRFRegressor classes to SKL-like xgboost API
- also added n_gpus and gpu_id parameters to SKL classes
- added documentation describing how to use xgboost for random forests,
  as well as existing caveats
canonizer authored and trivialfis committed Mar 12, 2019
1 parent 6fb4c5e commit a36c3ed
Showing 4 changed files with 240 additions and 55 deletions.
89 changes: 89 additions & 0 deletions doc/rf.rst
@@ -0,0 +1,89 @@
#########################
Random Forests in XGBoost
#########################

XGBoost is normally used to train gradient-boosted decision trees and other
gradient-boosted models. Random forests use the same model representation and
inference as gradient-boosted decision trees, but a different training algorithm.
XGBoost provides parameters that enable training a forest in a random forest fashion.


****************
With XGBoost API
****************

The following parameters must be set to enable random forest training.

* ``booster`` should be set to ``gbtree``, as we are training forests. Note that as this
is the default, this parameter needn't be set explicitly.
* ``subsample`` must be set to a value less than 1 to enable random selection of training
cases (rows).
* One of the ``colsample_by*`` parameters must be set to a value less than 1 to enable
random selection of columns. Normally, ``colsample_bynode`` is set to a value less than 1
to randomly sample columns at each tree split.
* ``num_parallel_tree`` should be set to the size of the forest being trained.
* ``num_boost_round`` should be set to 1. Note that this is a keyword argument to
``train()``, and is not part of the parameter dictionary.
* ``eta`` (alias: ``learning_rate``) must be set to 1 when training random forest
regression.
* ``random_state`` can be used to seed the random number generator.


Other parameters should be set in a similar way to how they are set for gradient
boosting. For instance, ``objective`` will typically be ``reg:linear`` for regression and
``binary:logistic`` for classification, ``lambda`` should be set according to the desired
regularization weight, etc.

If both ``num_parallel_tree`` and ``num_boost_round`` are greater than 1, training will
use a combination of the random forest and gradient boosting strategies. It will perform
``num_boost_round`` rounds, boosting a random forest of ``num_parallel_tree`` trees at
each round. If early stopping is not enabled, the final model will consist of
``num_parallel_tree`` * ``num_boost_round`` trees.
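For example, with ``num_parallel_tree=100`` and ``num_boost_round=10``, the final model
would contain 1000 trees.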

Here is a sample parameter dictionary for training a random forest on a GPU using
xgboost::

  params = {
      'colsample_bynode': 0.8,
      'learning_rate': 1,
      'max_depth': 5,
      'num_parallel_tree': 100,
      'objective': 'binary:logistic',
      'subsample': 0.8,
      'tree_method': 'gpu_hist'
  }

A random forest model can then be trained as follows::

  bst = train(params, dmatrix, num_boost_round=1)
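
For completeness, here is a minimal sketch of how ``dmatrix`` might be built and the
trained forest used for prediction. The synthetic data and array shapes are purely
illustrative, the ``params`` dictionary is the one defined above, and a CUDA-enabled GPU
is assumed since ``tree_method`` is ``gpu_hist``::

  import numpy as np
  import xgboost as xgb

  # Illustrative synthetic data: any numeric feature matrix with binary labels works.
  X = np.random.rand(1000, 10)
  y = np.random.randint(2, size=1000)
  dmatrix = xgb.DMatrix(X, label=y)

  # A single boosting round trains the whole forest of num_parallel_tree trees.
  bst = xgb.train(params, dmatrix, num_boost_round=1)
  preds = bst.predict(dmatrix)  # predicted probabilities for binary:logistic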


**************************
With Scikit-Learn-Like API
**************************

``XGBRFClassifier`` and ``XGBRFRegressor`` are SKL-like classes that provide random forest
functionality. They are basically versions of ``XGBClassifier`` and ``XGBRegressor`` that
train random forests instead of gradient-boosted models, with the default values and
meaning of some of the parameters adjusted accordingly. In particular:

* ``n_estimators`` specifies the size of the forest to be trained; it is converted to
``num_parallel_tree`` rather than to the number of boosting rounds
* ``learning_rate`` is set to 1 by default
* ``colsample_bynode`` and ``subsample`` are set to 0.8 by default
* ``booster`` is always ``gbtree``

Note that these classes have a smaller selection of parameters compared to using
``train()``. In particular, it is impossible to combine random forests with gradient
boosting using this API.
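
A minimal usage sketch of these classes is given below. It assumes scikit-learn is
available for generating data, and the dataset, split, and hyperparameter values are
purely illustrative::

  from sklearn.datasets import make_classification
  from sklearn.model_selection import train_test_split
  from xgboost import XGBRFClassifier

  # Illustrative synthetic binary classification problem.
  X, y = make_classification(n_samples=1000, n_features=20, random_state=42)
  X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)

  # n_estimators sets the forest size (num_parallel_tree); training uses one boosting round.
  clf = XGBRFClassifier(n_estimators=100, max_depth=5, random_state=42)
  clf.fit(X_train, y_train)
  print(clf.score(X_test, y_test))

``XGBRFRegressor`` is used in the same way for regression tasks.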


*******
Caveats
*******

* XGBoost uses a second-order approximation to the objective function. This can lead to
results that differ from a random forest implementation that uses the exact value of the
objective function.
* XGBoost does not perform replacement when subsampling training cases. Each training case
can occur in a subsampled set either 0 or 1 time.
2 changes: 2 additions & 0 deletions python-package/xgboost/__init__.py
@@ -13,6 +13,7 @@
from . import rabit # noqa
try:
from .sklearn import XGBModel, XGBClassifier, XGBRegressor, XGBRanker
from .sklearn import XGBRFClassifier, XGBRFRegressor
from .plotting import plot_importance, plot_tree, to_graphviz
except ImportError:
pass
@@ -24,4 +25,5 @@
__all__ = ['DMatrix', 'Booster',
'train', 'cv',
'XGBModel', 'XGBClassifier', 'XGBRegressor', 'XGBRanker',
'XGBRFClassifier', 'XGBRFRegressor',
'plot_importance', 'plot_tree', 'to_graphviz']
153 changes: 114 additions & 39 deletions python-package/xgboost/sklearn.py
@@ -61,9 +61,9 @@ class XGBModel(XGBModelBase):
learning_rate : float
Boosting learning rate (xgb's "eta")
n_estimators : int
Number of boosted trees to fit.
silent : boolean
Whether to print messages while running boosting.
Number of trees to fit.
verbosity : int
The degree of verbosity. Valid values are 0 (silent) - 3 (debug).
objective : string or callable
Specify the learning task and the corresponding learning objective or
a custom objective function to be used (see note below).
@@ -84,7 +84,9 @@ class XGBModel(XGBModelBase):
colsample_bytree : float
Subsample ratio of columns when constructing each tree.
colsample_bylevel : float
Subsample ratio of columns for each split, in each level.
Subsample ratio of columns for each level.
colsample_bynode : float
Subsample ratio of columns for each split.
reg_alpha : float (xgb's alpha)
L1 regularization term on weights
reg_lambda : float (xgb's lambda)
@@ -132,18 +134,18 @@ class XGBModel(XGBModelBase):
"""

def __init__(self, max_depth=3, learning_rate=0.1, n_estimators=100,
silent=True, objective="reg:linear", booster='gbtree',
n_jobs=1, nthread=None, gamma=0, min_child_weight=1, max_delta_step=0,
subsample=1, colsample_bytree=1, colsample_bylevel=1,
reg_alpha=0, reg_lambda=1, scale_pos_weight=1,
verbosity=1, objective="reg:linear", booster='gbtree',
n_jobs=1, nthread=None, gamma=0, min_child_weight=1,
max_delta_step=0, subsample=1, colsample_bytree=1, colsample_bylevel=1,
colsample_bynode=1, reg_alpha=0, reg_lambda=1, scale_pos_weight=1,
base_score=0.5, random_state=0, seed=None, missing=None,
importance_type="gain", **kwargs):
if not SKLEARN_INSTALLED:
raise XGBoostError('sklearn needs to be installed in order to use this module')
self.max_depth = max_depth
self.learning_rate = learning_rate
self.n_estimators = n_estimators
self.silent = silent
self.verbosity = verbosity
self.objective = objective
self.booster = booster
self.gamma = gamma
@@ -152,6 +154,7 @@ def __init__(self, max_depth=3, learning_rate=0.1, n_estimators=100,
self.subsample = subsample
self.colsample_bytree = colsample_bytree
self.colsample_bylevel = colsample_bylevel
self.colsample_bynode = colsample_bynode
self.reg_alpha = reg_alpha
self.reg_lambda = reg_lambda
self.scale_pos_weight = scale_pos_weight
@@ -237,12 +240,14 @@ def get_xgb_params(self):
else:
xgb_params['nthread'] = n_jobs

xgb_params['verbosity'] = 0 if self.silent else 0

if xgb_params['nthread'] <= 0:
xgb_params.pop('nthread', None)
return xgb_params

def get_num_boosting_rounds(self):
"""Gets the number of xgboost boosting rounds."""
return self.n_estimators

def save_model(self, fname):
"""
Save the model to a file.
@@ -371,7 +376,7 @@ def fit(self, X, y, sample_weight=None, eval_set=None, eval_metric=None,
params.update({'eval_metric': eval_metric})

self._Booster = train(params, trainDmatrix,
self.n_estimators, evals=evals,
self.get_num_boosting_rounds(), evals=evals,
early_stopping_rounds=early_stopping_rounds,
evals_result=evals_result, obj=obj, feval=feval,
verbose_eval=verbose, xgb_model=xgb_model,
@@ -583,21 +588,22 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
__doc__ = "Implementation of the scikit-learn API for XGBoost classification.\n\n" \
+ '\n'.join(XGBModel.__doc__.split('\n')[2:])

def __init__(self, max_depth=3, learning_rate=0.1,
n_estimators=100, silent=True,
def __init__(self, max_depth=3, learning_rate=0.1, n_estimators=100, verbosity=1,
objective="binary:logistic", booster='gbtree',
n_jobs=1, nthread=None, gamma=0, min_child_weight=1,
max_delta_step=0, subsample=1, colsample_bytree=1, colsample_bylevel=1,
reg_alpha=0, reg_lambda=1, scale_pos_weight=1,
n_jobs=1, nthread=None, gamma=0, min_child_weight=1, max_delta_step=0,
subsample=1, colsample_bytree=1, colsample_bylevel=1,
colsample_bynode=1, reg_alpha=0, reg_lambda=1, scale_pos_weight=1,
base_score=0.5, random_state=0, seed=None, missing=None, **kwargs):
super(XGBClassifier, self).__init__(max_depth, learning_rate,
n_estimators, silent, objective, booster,
n_jobs, nthread, gamma, min_child_weight,
max_delta_step, subsample,
colsample_bytree, colsample_bylevel,
reg_alpha, reg_lambda,
scale_pos_weight, base_score,
random_state, seed, missing, **kwargs)
super(XGBClassifier, self).__init__(
max_depth=max_depth, learning_rate=learning_rate, n_estimators=n_estimators,
verbosity=verbosity, objective=objective, booster=booster,
n_jobs=n_jobs, nthread=nthread, gamma=gamma,
min_child_weight=min_child_weight, max_delta_step=max_delta_step,
subsample=subsample, colsample_bytree=colsample_bytree,
colsample_bylevel=colsample_bylevel, colsample_bynode=colsample_bynode,
reg_alpha=reg_alpha, reg_lambda=reg_lambda, scale_pos_weight=scale_pos_weight,
base_score=base_score, random_state=random_state, seed=seed, missing=missing,
**kwargs)

def fit(self, X, y, sample_weight=None, eval_set=None, eval_metric=None,
early_stopping_rounds=None, verbose=True, xgb_model=None,
@@ -705,9 +711,8 @@ def fit(self, X, y, sample_weight=None, eval_set=None, eval_metric=None,
train_dmatrix = DMatrix(X, label=training_labels,
missing=self.missing, nthread=self.n_jobs)

self._Booster = train(xgb_options, train_dmatrix, self.n_estimators,
evals=evals,
early_stopping_rounds=early_stopping_rounds,
self._Booster = train(xgb_options, train_dmatrix, self.get_num_boosting_rounds(),
evals=evals, early_stopping_rounds=early_stopping_rounds,
evals_result=evals_result, obj=obj, feval=feval,
verbose_eval=verbose, xgb_model=xgb_model,
callbacks=callbacks)
@@ -863,12 +868,76 @@ def evals_result(self):
return evals_result


class XGBRFClassifier(XGBClassifier):
# pylint: disable=missing-docstring
__doc__ = "Implementation of the scikit-learn API "\
+ "for XGBoost random forest classification.\n\n"\
+ '\n'.join(XGBModel.__doc__.split('\n')[2:])

def __init__(self, max_depth=3, learning_rate=1, n_estimators=100, verbosity=1,
objective="binary:logistic", n_jobs=1, nthread=None, gamma=0,
min_child_weight=1, max_delta_step=0, subsample=0.8, colsample_bytree=1,
colsample_bylevel=1, colsample_bynode=0.8, reg_alpha=0, reg_lambda=1,
scale_pos_weight=1, base_score=0.5, random_state=0, seed=None,
missing=None, **kwargs):
super(XGBRFClassifier, self).__init__(
max_depth=max_depth, learning_rate=learning_rate, n_estimators=n_estimators,
verbosity=verbosity, objective=objective, booster='gbtree',
n_jobs=n_jobs, nthread=nthread, gamma=gamma,
min_child_weight=min_child_weight, max_delta_step=max_delta_step,
subsample=subsample, colsample_bytree=colsample_bytree,
colsample_bylevel=colsample_bylevel, colsample_bynode=colsample_bynode,
reg_alpha=reg_alpha, reg_lambda=reg_lambda, scale_pos_weight=scale_pos_weight,
base_score=base_score, random_state=random_state, seed=seed, missing=missing,
**kwargs)

def get_xgb_params(self):
params = super(XGBRFClassifier, self).get_xgb_params()
params['num_parallel_tree'] = self.n_estimators
return params

def get_num_boosting_rounds(self):
return 1


class XGBRegressor(XGBModel, XGBRegressorBase):
# pylint: disable=missing-docstring
__doc__ = "Implementation of the scikit-learn API for XGBoost regression.\n\n"\
+ '\n'.join(XGBModel.__doc__.split('\n')[2:])


class XGBRFRegressor(XGBRegressor):
# pylint: disable=missing-docstring
__doc__ = "Implementation of the scikit-learn API "\
+ "for XGBoost random forest regression.\n\n"\
+ '\n'.join(XGBModel.__doc__.split('\n')[2:])

def __init__(self, max_depth=3, learning_rate=1, n_estimators=100, verbosity=1,
objective="reg:linear", n_jobs=1, nthread=None, gamma=0,
min_child_weight=1, max_delta_step=0, subsample=0.8, colsample_bytree=1,
colsample_bylevel=1, colsample_bynode=0.8, reg_alpha=0, reg_lambda=1,
scale_pos_weight=1, base_score=0.5, random_state=0, seed=None,
missing=None, **kwargs):
super(XGBRFRegressor, self).__init__(
max_depth=max_depth, learning_rate=learning_rate, n_estimators=n_estimators,
verbosity=verbosity, objective=objective, booster='gbtree',
n_jobs=n_jobs, nthread=nthread, gamma=gamma,
min_child_weight=min_child_weight, max_delta_step=max_delta_step,
subsample=subsample, colsample_bytree=colsample_bytree,
colsample_bylevel=colsample_bylevel, colsample_bynode=colsample_bynode,
reg_alpha=reg_alpha, reg_lambda=reg_lambda, scale_pos_weight=scale_pos_weight,
base_score=base_score, random_state=random_state, seed=seed, missing=missing,
**kwargs)

def get_xgb_params(self):
params = super(XGBRFRegressor, self).get_xgb_params()
params['num_parallel_tree'] = self.n_estimators
return params

def get_num_boosting_rounds(self):
return 1


class XGBRanker(XGBModel):
# pylint: disable=missing-docstring,too-many-arguments,invalid-name
"""Implementation of the Scikit-Learn API for XGBoost Ranking.
@@ -881,8 +950,8 @@ class XGBRanker(XGBModel):
Boosting learning rate (xgb's "eta")
n_estimators : int
Number of boosted trees to fit.
silent : boolean
Whether to print messages while running boosting.
verbosity : int
The degree of verbosity. Valid values are 0 (silent) - 3 (debug).
objective : string
Specify the learning task and the corresponding learning objective.
The objective name must start with "rank:".
@@ -903,7 +972,9 @@ class XGBRanker(XGBModel):
colsample_bytree : float
Subsample ratio of columns when constructing each tree.
colsample_bylevel : float
Subsample ratio of columns for each split, in each level.
Subsample ratio of columns for each level.
colsample_bynode : float
Subsample ratio of columns for each split.
reg_alpha : float (xgb's alpha)
L1 regularization term on weights
reg_lambda : float (xgb's lambda)
@@ -966,18 +1037,22 @@ class XGBRanker(XGBModel):
"""

def __init__(self, max_depth=3, learning_rate=0.1, n_estimators=100,
silent=True, objective="rank:pairwise", booster='gbtree',
verbosity=1, objective="rank:pairwise", booster='gbtree',
n_jobs=-1, nthread=None, gamma=0, min_child_weight=1, max_delta_step=0,
subsample=1, colsample_bytree=1, colsample_bylevel=1,
subsample=1, colsample_bytree=1, colsample_bylevel=1, colsample_bynode=1,
reg_alpha=0, reg_lambda=1, scale_pos_weight=1,
base_score=0.5, random_state=0, seed=None, missing=None, **kwargs):

super(XGBRanker, self).__init__(max_depth, learning_rate,
n_estimators, silent, objective, booster,
n_jobs, nthread, gamma, min_child_weight, max_delta_step,
subsample, colsample_bytree, colsample_bylevel,
reg_alpha, reg_lambda, scale_pos_weight,
base_score, random_state, seed, missing)
super(XGBRanker, self).__init__(
max_depth=max_depth, learning_rate=learning_rate, n_estimators=n_estimators,
verbosity=verbosity, objective=objective, booster=booster,
n_jobs=n_jobs, nthread=nthread, gamma=gamma,
min_child_weight=min_child_weight, max_delta_step=max_delta_step,
subsample=subsample, colsample_bytree=colsample_bytree,
colsample_bylevel=colsample_bylevel, colsample_bynode=colsample_bynode,
reg_alpha=reg_alpha, reg_lambda=reg_lambda,
scale_pos_weight=scale_pos_weight, base_score=base_score,
random_state=random_state, seed=seed, missing=missing, **kwargs)
if callable(self.objective):
raise ValueError("custom objective function not supported by XGBRanker")
elif "rank:" not in self.objective: