From 2a00b6ffbc979fcbe68a8485aaa78d80005b163b Mon Sep 17 00:00:00 2001 From: James Lamb Date: Tue, 2 Mar 2021 07:29:08 -0600 Subject: [PATCH] [dask] [ci] add support for scikit-learn 0.24+ in tests (fixes #4031) (#4032) * [dask] [ci] add support for scikit-learn 0.24+ in tests (fixes #4031) * Update tests/python_package_test/test_dask.py Co-authored-by: Nikita Titov * try upgrading mixtexsetup * they changed the executable name UGH * more changes for executable name * another path change * changing package mirrors * undo experiments Co-authored-by: Nikita Titov --- tests/python_package_test/test_dask.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/tests/python_package_test/test_dask.py b/tests/python_package_test/test_dask.py index 6b112b705a3e..5f7784190e4b 100644 --- a/tests/python_package_test/test_dask.py +++ b/tests/python_package_test/test_dask.py @@ -28,12 +28,16 @@ from dask.array.utils import assert_eq from dask.distributed import Client, LocalCluster, default_client, wait from distributed.utils_test import client, cluster_fixture, gen_cluster, loop +from pkg_resources import parse_version from scipy.sparse import csr_matrix from scipy.stats import spearmanr +from sklearn import __version__ as sk_version from sklearn.datasets import make_blobs, make_regression from .utils import make_ranking +sk_version = parse_version(sk_version) + # time, in seconds, to wait for the Dask client to close. Used to avoid teardown errors # see https://distributed.dask.org/en/latest/api.html#distributed.Client.close CLIENT_CLOSE_TIMEOUT = 120 @@ -1253,5 +1257,9 @@ def test_sklearn_integration(estimator, check, client): # this test is separate because it takes a not-yet-constructed estimator @pytest.mark.parametrize("estimator", list(_tested_estimators())) def test_parameters_default_constructible(estimator): - name, Estimator = estimator.__class__.__name__, estimator.__class__ + name = estimator.__class__.__name__ + if sk_version >= parse_version("0.24"): + Estimator = estimator + else: + Estimator = estimator.__class__ sklearn_checks.check_parameters_default_constructible(name, Estimator)