[python] Re-enable scikit-learn 0.22+ support (#2949)
* Revert "specify the last supported version of scikit-learn (#2637)"

This reverts commit d100277.

* ban scikit-learn 0.22.0 and skip broken test

* fix updated test

* fix lint test

* Revert "fix lint test"

This reverts commit 8b4db08.
StrikerRUS authored Apr 10, 2020
1 parent 505a145 commit c633c6c
Showing 12 changed files with 33 additions and 27 deletions.
2 changes: 1 addition & 1 deletion .ci/test.sh
@@ -74,7 +74,7 @@ if [[ $TASK == "r-package" ]]; then
     exit 0
 fi

-conda install -q -y -n $CONDA_ENV joblib matplotlib numpy pandas psutil pytest python-graphviz "scikit-learn<=0.21.3" scipy
+conda install -q -y -n $CONDA_ENV joblib matplotlib numpy pandas psutil pytest python-graphviz scikit-learn scipy

 if [[ $OS_NAME == "macos" ]] && [[ $COMPILER == "clang" ]]; then
     # fix "OMP: Error #15: Initializing libiomp5.dylib, but found libomp.dylib already initialized." (OpenMP library conflict due to conda's MKL)
2 changes: 1 addition & 1 deletion .ci/test_windows.ps1
@@ -17,7 +17,7 @@ conda init powershell
 conda activate
 conda config --set always_yes yes --set changeps1 no
 conda update -q -y conda
-conda create -q -y -n $env:CONDA_ENV python=$env:PYTHON_VERSION joblib matplotlib numpy pandas psutil pytest python-graphviz "scikit-learn<=0.21.3" scipy ; Check-Output $?
+conda create -q -y -n $env:CONDA_ENV python=$env:PYTHON_VERSION joblib matplotlib numpy pandas psutil pytest python-graphviz scikit-learn scipy ; Check-Output $?
 conda activate $env:CONDA_ENV

 if ($env:TASK -eq "regular") {
2 changes: 1 addition & 1 deletion docker/dockerfile-python
@@ -18,7 +18,7 @@ RUN apt-get update && \
     export PATH="$CONDA_DIR/bin:$PATH" && \
     conda config --set always_yes yes --set changeps1 no && \
     # lightgbm
-    conda install -q -y numpy scipy "scikit-learn<=0.21.3" pandas && \
+    conda install -q -y numpy scipy scikit-learn pandas && \
     git clone --recursive --branch stable --depth 1 https://github.com/Microsoft/LightGBM && \
     cd LightGBM/python-package && python setup.py install && \
     # clean
4 changes: 2 additions & 2 deletions docker/gpu/dockerfile.gpu
@@ -75,8 +75,8 @@ RUN echo "export PATH=$CONDA_DIR/bin:"'$PATH' > /etc/profile.d/conda.sh && \
     rm ~/miniconda.sh

 RUN conda config --set always_yes yes --set changeps1 no && \
-    conda create -y -q -n py2 python=2.7 mkl numpy scipy "scikit-learn<=0.21.3" jupyter notebook ipython pandas matplotlib && \
-    conda create -y -q -n py3 python=3.6 mkl numpy scipy "scikit-learn<=0.21.3" jupyter notebook ipython pandas matplotlib
+    conda create -y -q -n py2 python=2.7 mkl numpy scipy scikit-learn jupyter notebook ipython pandas matplotlib && \
+    conda create -y -q -n py3 python=3.6 mkl numpy scipy scikit-learn jupyter notebook ipython pandas matplotlib

 #################################################################################################################
 # LightGBM
4 changes: 2 additions & 2 deletions docs/GPU-Tutorial.rst
@@ -1,4 +1,4 @@
-LightGBM GPU Tutorial
+LightGBM GPU Tutorial
 =====================

 The purpose of this document is to give you a quick step-by-step tutorial on GPU training.
@@ -78,7 +78,7 @@ If you want to use the Python interface of LightGBM, you can install it now (alo
 ::

     sudo apt-get -y install python-pip
-    sudo -H pip install setuptools numpy scipy "scikit-learn<=0.21.3" -U
+    sudo -H pip install setuptools numpy scipy scikit-learn -U
     cd python-package/
     sudo python setup.py install --precompile
     cd ..
4 changes: 0 additions & 4 deletions docs/Python-API.rst
@@ -24,10 +24,6 @@ Training API
 Scikit-learn API
 ----------------

-.. warning::
-
-    The last supported version of scikit-learn is ``0.21.3``. Our estimators are incompatible with newer versions.
-
 .. autosummary::
     :toctree: pythonapi/

4 changes: 2 additions & 2 deletions docs/Python-Intro.rst
@@ -15,11 +15,11 @@ Install
 -------

 Install Python-package dependencies,
-``setuptools``, ``wheel``, ``numpy`` and ``scipy`` are required, ``scikit-learn<=0.21.3`` is required for sklearn interface and recommended:
+``setuptools``, ``wheel``, ``numpy`` and ``scipy`` are required, ``scikit-learn`` is required for sklearn interface and recommended:

 ::

-    pip install setuptools wheel numpy scipy "scikit-learn<=0.21.3" -U
+    pip install setuptools wheel numpy scipy scikit-learn -U

 Refer to `Python-package`_ folder for the installation guide.

2 changes: 1 addition & 1 deletion examples/python-guide/README.md
@@ -8,7 +8,7 @@ You should install LightGBM [Python-package](https://github.com/microsoft/LightG
 You also need scikit-learn, pandas, matplotlib (only for plot example), and scipy (only for logistic regression example) to run the examples, but they are not required for the package itself. You can install them with pip:

 ```
-pip install "scikit-learn<=0.21.3" pandas matplotlib scipy -U
+pip install scikit-learn pandas matplotlib scipy -U
 ```

 Now you can run examples in this folder, for example:
19 changes: 13 additions & 6 deletions python-package/lightgbm/compat.py
@@ -116,16 +116,24 @@ class DataTable(object):
     from sklearn.preprocessing import LabelEncoder
     from sklearn.utils.class_weight import compute_sample_weight
     from sklearn.utils.multiclass import check_classification_targets
-    from sklearn.utils.validation import (assert_all_finite, check_X_y,
-                                          check_array, check_consistent_length)
+    from sklearn.utils.validation import assert_all_finite, check_X_y, check_array
     try:
         from sklearn.model_selection import StratifiedKFold, GroupKFold
         from sklearn.exceptions import NotFittedError
     except ImportError:
         from sklearn.cross_validation import StratifiedKFold, GroupKFold
         from sklearn.utils.validation import NotFittedError
+    try:
+        from sklearn.utils.validation import _check_sample_weight
+    except ImportError:
+        from sklearn.utils.validation import check_consistent_length
+
+        # dummy function to support older version of scikit-learn
+        def _check_sample_weight(sample_weight, X, dtype=None):
+            check_consistent_length(sample_weight, X)
+            return sample_weight
+
     SKLEARN_INSTALLED = True
-    from sklearn import __version__ as SKLEARN_VERSION
     _LGBMModelBase = BaseEstimator
     _LGBMRegressorBase = RegressorMixin
     _LGBMClassifierBase = ClassifierMixin
@@ -135,13 +143,12 @@ class DataTable(object):
     _LGBMGroupKFold = GroupKFold
     _LGBMCheckXY = check_X_y
     _LGBMCheckArray = check_array
-    _LGBMCheckConsistentLength = check_consistent_length
+    _LGBMCheckSampleWeight = _check_sample_weight
     _LGBMAssertAllFinite = assert_all_finite
     _LGBMCheckClassificationTargets = check_classification_targets
     _LGBMComputeSampleWeight = compute_sample_weight
 except ImportError:
     SKLEARN_INSTALLED = False
-    SKLEARN_VERSION = '0.0.0'
     _LGBMModelBase = object
     _LGBMClassifierBase = object
     _LGBMRegressorBase = object
@@ -151,7 +158,7 @@ class DataTable(object):
     _LGBMGroupKFold = None
     _LGBMCheckXY = None
     _LGBMCheckArray = None
-    _LGBMCheckConsistentLength = None
+    _LGBMCheckSampleWeight = None
     _LGBMAssertAllFinite = None
     _LGBMCheckClassificationTargets = None
     _LGBMComputeSampleWeight = None
10 changes: 4 additions & 6 deletions python-package/lightgbm/sklearn.py
@@ -7,9 +7,9 @@
 import numpy as np

 from .basic import Dataset, LightGBMError, _ConfigAliases
-from .compat import (SKLEARN_INSTALLED, SKLEARN_VERSION, _LGBMClassifierBase,
+from .compat import (SKLEARN_INSTALLED, _LGBMClassifierBase,
                      LGBMNotFittedError, _LGBMLabelEncoder, _LGBMModelBase,
-                     _LGBMRegressorBase, _LGBMCheckXY, _LGBMCheckArray, _LGBMCheckConsistentLength,
+                     _LGBMRegressorBase, _LGBMCheckXY, _LGBMCheckArray, _LGBMCheckSampleWeight,
                      _LGBMAssertAllFinite, _LGBMCheckClassificationTargets, _LGBMComputeSampleWeight,
                      argc_, range_, zip_, string_type, DataFrame, DataTable)
 from .engine import train
@@ -298,9 +298,6 @@ def __init__(self, boosting_type='gbdt', num_leaves=31, max_depth=-1,
         """
         if not SKLEARN_INSTALLED:
             raise LightGBMError('Scikit-learn is required for this module')
-        elif SKLEARN_VERSION > '0.21.3':
-            raise RuntimeError("The last supported version of scikit-learn is 0.21.3.\n"
-                               "Found version: {0}.".format(SKLEARN_VERSION))

         self.boosting_type = boosting_type
         self.objective = objective
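
Reviewer note: the guard removed from `__init__` compared `SKLEARN_VERSION` with `'0.21.3'` as plain strings. The commit message does not give this as a reason for the removal, but string comparison of version numbers is fragile in general, because it proceeds character by character. The sketch below illustrates the difference; `release_tuple` is a hypothetical helper, not part of LightGBM or scikit-learn.

```python
import re


def release_tuple(version):
    """Turn '0.22.2.post1' into (0, 22, 2); stop at the first non-numeric piece."""
    parts = []
    for piece in version.split("."):
        match = re.match(r"\d+", piece)
        if match is None:
            break
        parts.append(int(match.group()))
    return tuple(parts)


# String comparison is character by character, so "9" sorts after "2".
print("0.9.0" > "0.21.3")  # True, although 0.9.0 is the older release

# Tuple comparison treats each component as an integer.
print(release_tuple("0.9.0") > release_tuple("0.21.3"))  # False
```

With the string form, an old release such as 0.9.0 would have been rejected as "newer than 0.21.3".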
@@ -547,7 +544,8 @@ def fit(self, X, y,

         if not isinstance(X, (DataFrame, DataTable)):
             _X, _y = _LGBMCheckXY(X, y, accept_sparse=True, force_all_finite=False, ensure_min_samples=2)
-            _LGBMCheckConsistentLength(_X, _y, sample_weight)
+            if sample_weight is not None:
+                sample_weight = _LGBMCheckSampleWeight(sample_weight, _X)
         else:
             _X, _y = X, y

2 changes: 1 addition & 1 deletion python-package/setup.py
@@ -276,7 +276,7 @@ def run(self):
           install_requires=[
               'numpy',
               'scipy',
-              'scikit-learn<=0.21.3'
+              'scikit-learn!=0.22.0'
           ],
           maintainer='Guolin Ke',
           maintainer_email='[email protected]',
5 changes: 5 additions & 0 deletions tests/python_package_test/test_sklearn.py
@@ -293,6 +293,11 @@ def test_sklearn_integration(self):
                 check_name = check.func.__name__ if hasattr(check, 'func') else check.__name__
                 if check_name == 'check_estimators_nan_inf':
                     continue  # skip test because LightGBM deals with nan
+                elif check_name == "check_no_attributes_set_in_init":
+                    # skip test because scikit-learn incorrectly asserts that
+                    # private attributes cannot be set in __init__
+                    # (see https://github.com/microsoft/LightGBM/issues/2628)
+                    continue
                 try:
                     check(name, estimator)
                 except SkipTest as message:
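
Reviewer note: the test skips individual checks by name, and it reads the name from `check.func` when the check is a `functools.partial` object, since partials have no `__name__` of their own. The sketch below shows that lookup on its own. The three check functions are hypothetical stand-ins, not scikit-learn's real checks, and the `SKIPPED_CHECKS` set is a variation on the `if`/`elif` chain used in the test.

```python
from functools import partial


# Hypothetical stand-ins for scikit-learn's estimator checks.
def check_estimators_nan_inf(name, estimator):
    raise AssertionError("should have been skipped")


def check_no_attributes_set_in_init(name, estimator):
    raise AssertionError("should have been skipped")


def check_fit_returns_self(name, estimator, strict=False):
    assert estimator is not None


SKIPPED_CHECKS = {"check_estimators_nan_inf", "check_no_attributes_set_in_init"}


def get_check_name(check):
    # partial objects expose the wrapped function on .func
    return check.func.__name__ if hasattr(check, "func") else check.__name__


def run_checks(checks, name, estimator):
    """Run every check that is not skipped and return the names that ran."""
    executed = []
    for check in checks:
        check_name = get_check_name(check)
        if check_name in SKIPPED_CHECKS:
            continue
        check(name, estimator)
        executed.append(check_name)
    return executed


checks = [
    check_estimators_nan_inf,
    partial(check_no_attributes_set_in_init),
    partial(check_fit_returns_self, strict=True),
]
executed = run_checks(checks, "DummyEstimator", object())
print(executed)  # ['check_fit_returns_self']
```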
