Skip to content

Commit

Permalink
Sklearn kwargs (#2338)
Browse files Browse the repository at this point in the history
* Added kwargs support for Sklearn API

* Updated NEWS and CONTRIBUTORS

* Fixed CONTRIBUTORS.md

* Added clarification of **kwargs and test for proper usage

* Fixed lint error

* Fixed more lint errors and clf assigned but never used

* Fixed more lint errors

* Fixed more lint errors

* Fixed issue with changes from different branch bleeding over

* Fixed issue with changes from other branch bleeding over

* Added note that kwargs may not be compatible with Sklearn

* Fixed linting on kwargs note
  • Loading branch information
gaw89 authored and terrytangyuan committed May 24, 2017
1 parent 6cea1e3 commit 0f3a404
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 4 deletions.
1 change: 1 addition & 0 deletions CONTRIBUTORS.md
Original file line number Diff line number Diff line change
Expand Up @@ -65,3 +65,4 @@ List of Contributors
* [Adam Pocock](https://github.com/Craigacp)
* [Rory Mitchell](https://github.com/RAMitchell)
- Rory is the author of the GPU plugin and also contributed the cmake build system and windows continuous integration
* [Gideon Whitehead](https://github.com/gaw89)
3 changes: 3 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@ XGBoost Change Log
This file records the changes in xgboost library in reverse chronological order.

## in progress version
* Updated Sklearn API
- Updated to allow use of all XGBoost parameters via **kwargs.
- Updated nthread to n_jobs and seed to random_state (as per Sklearn convention).
* Refactored gbm to allow more friendly cache strategy
- Specialized some prediction routine
* Automatically remove nan from input data when it is sparse.
Expand Down
18 changes: 14 additions & 4 deletions python-package/xgboost/sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,14 @@ class XGBModel(XGBModelBase):
missing : float, optional
Value in the data which needs to be present as a missing value. If
None, defaults to np.nan.
**kwargs : dict, optional
Keyword arguments for XGBoost Booster object. Full documentation of parameters can
be found here: https://github.com/dmlc/xgboost/blob/master/doc/parameter.md.
Attempting to set a parameter via the constructor args and **kwargs dict simultaneously
will result in a TypeError.
Note:
**kwargs is unsupported by Sklearn. We do not guarantee that parameters passed via
this argument will interact properly with Sklearn.
Note
----
Expand All @@ -124,7 +132,7 @@ def __init__(self, max_depth=3, learning_rate=0.1, n_estimators=100,
n_jobs=1, nthread=None, gamma=0, min_child_weight=1, max_delta_step=0,
subsample=1, colsample_bytree=1, colsample_bylevel=1,
reg_alpha=0, reg_lambda=1, scale_pos_weight=1,
base_score=0.5, random_state=0, seed=None, missing=None):
base_score=0.5, random_state=0, seed=None, missing=None, **kwargs):
if not SKLEARN_INSTALLED:
raise XGBoostError('sklearn needs to be installed in order to use this module')
self.max_depth = max_depth
Expand All @@ -133,7 +141,6 @@ def __init__(self, max_depth=3, learning_rate=0.1, n_estimators=100,
self.silent = silent
self.objective = objective
self.booster = booster

self.nthread = nthread
self.gamma = gamma
self.min_child_weight = min_child_weight
Expand All @@ -146,6 +153,7 @@ def __init__(self, max_depth=3, learning_rate=0.1, n_estimators=100,
self.scale_pos_weight = scale_pos_weight
self.base_score = base_score
self.missing = missing if missing is not None else np.nan
self.kwargs = kwargs
self._Booster = None
if seed:
warnings.warn('The seed parameter is deprecated as of version .6.'
Expand Down Expand Up @@ -192,6 +200,8 @@ def get_booster(self):
def get_params(self, deep=False):
"""Get parameter.s"""
params = super(XGBModel, self).get_params(deep=deep)
if isinstance(self.kwargs, dict): # if kwargs is a dict, update params accordingly
params.update(self.kwargs)
if params['missing'] is np.nan:
params['missing'] = None # sklearn doesn't handle nan. see #4725
if not params.get('eval_metric', True):
Expand Down Expand Up @@ -388,15 +398,15 @@ def __init__(self, max_depth=3, learning_rate=0.1,
n_jobs=1, nthread=None, gamma=0, min_child_weight=1,
max_delta_step=0, subsample=1, colsample_bytree=1, colsample_bylevel=1,
reg_alpha=0, reg_lambda=1, scale_pos_weight=1,
base_score=0.5, random_state=0, seed=None, missing=None):
base_score=0.5, random_state=0, seed=None, missing=None, **kwargs):
super(XGBClassifier, self).__init__(max_depth, learning_rate,
n_estimators, silent, objective, booster,
n_jobs, nthread, gamma, min_child_weight,
max_delta_step, subsample,
colsample_bytree, colsample_bylevel,
reg_alpha, reg_lambda,
scale_pos_weight, base_score,
random_state, seed, missing)
random_state, seed, missing, **kwargs)

def fit(self, X, y, sample_weight=None, eval_set=None, eval_metric=None,
early_stopping_rounds=None, verbose=True):
Expand Down
20 changes: 20 additions & 0 deletions tests/python/test_with_sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import xgboost as xgb
import testing as tm
import warnings
from nose.tools import raises

rng = np.random.RandomState(1994)

Expand Down Expand Up @@ -363,3 +364,22 @@ def test_nthread_deprecation():
with warnings.catch_warnings(record=True) as w:
xgb.XGBClassifier(nthread=1)
assert w[0].category == DeprecationWarning


def test_kwargs():
tm._skip_if_no_sklearn()

params = {'updater': 'grow_gpu', 'subsample': .5, 'n_jobs': -1}
clf = xgb.XGBClassifier(n_estimators=1000, **params)
assert clf.get_params()['updater'] == 'grow_gpu'
assert clf.get_params()['subsample'] == .5
assert clf.get_params()['n_estimators'] == 1000


@raises(TypeError)
def test_kwargs_error():
tm._skip_if_no_sklearn()

params = {'updater': 'grow_gpu', 'subsample': .5, 'n_jobs': -1}
clf = xgb.XGBClassifier(n_jobs=1000, **params)
assert isinstance(clf, xgb.XGBClassifier)

0 comments on commit 0f3a404

Please sign in to comment.