From 360309ce5a9e5ea715c1b34df87cfc955e79a521 Mon Sep 17 00:00:00 2001
From: Jan Stiborek
Date: Sat, 7 Nov 2020 21:54:51 +0100
Subject: [PATCH 01/19] migrated implementation from dask/dask-lightgbm

---
 .ci/test.sh                            |   2 +-
 python-package/lightgbm/dask.py        | 299 +++++++++++++++++++++++++
 python-package/setup.py                |   6 +
 tests/python_package_test/test_dask.py | 205 +++++++++++++++++
 4 files changed, 511 insertions(+), 1 deletion(-)
 create mode 100644 python-package/lightgbm/dask.py
 create mode 100644 tests/python_package_test/test_dask.py

diff --git a/.ci/test.sh b/.ci/test.sh
index 4cd181790ad7..0c28943d2458 100755
--- a/.ci/test.sh
+++ b/.ci/test.sh
@@ -70,7 +70,7 @@ if [[ $TASK == "if-else" ]]; then
     exit 0
 fi

-conda install -q -y -n $CONDA_ENV joblib matplotlib numpy pandas psutil pytest python-graphviz scikit-learn scipy
+conda install -q -y -n $CONDA_ENV joblib matplotlib numpy pandas psutil pytest python-graphviz scikit-learn scipy dask distributed dask-ml

 if [[ $OS_NAME == "macos" ]] && [[ $COMPILER == "clang" ]]; then
     # fix "OMP: Error #15: Initializing libiomp5.dylib, but found libomp.dylib already initialized." (OpenMP library conflict due to conda's MKL)
diff --git a/python-package/lightgbm/dask.py b/python-package/lightgbm/dask.py
new file mode 100644
index 000000000000..e70983040eca
--- /dev/null
+++ b/python-package/lightgbm/dask.py
@@ -0,0 +1,299 @@
"""Distributed training with LightGBM and Dask.distributed.

This module enables you to perform distributed training with LightGBM on Dask.Array and Dask.DataFrame collections.
It is based on the dask-xgboost package.
"""
import logging
from collections import defaultdict

try:
    from urllib.parse import urlparse
except ImportError:
    from urlparse import urlparse

import dask.array as da
import dask.dataframe as dd
import lightgbm
import numpy as np
import pandas as pd
from dask import delayed
from dask.distributed import wait, default_client, get_worker
from lightgbm.basic import _safe_call, _LIB
from toolz import first, assoc

try:
    import scipy.sparse as ss
except ImportError:
    ss = False

logger = logging.getLogger(__name__)


def _parse_host_port(address):
    parsed = urlparse(address)
    return parsed.hostname, parsed.port


def build_network_params(worker_addresses, local_worker_ip, local_listen_port, time_out):
    """Build network parameters suiltable for LightGBM C backend.

    Parameters
    ----------
    worker_addresses : iterable of str - collection of worker addresses in `<protocol>://<host>:port` format
    local_worker_ip : str
    local_listen_port : int
    listen_time_out : int

    Returns
    -------
    params: dict
    """
    addr_port_map = {addr: (local_listen_port + i) for i, addr in enumerate(worker_addresses)}
    params = {
        'machines': ','.join('%s:%d' % (_parse_host_port(addr)[0], port) for addr, port in addr_port_map.items()),
        'local_listen_port': addr_port_map[local_worker_ip],
        'time_out': time_out,
        'num_machines': len(addr_port_map)
    }
    return params


def _concat(seq):
    if isinstance(seq[0], np.ndarray):
        return np.concatenate(seq, axis=0)
    elif isinstance(seq[0], (pd.DataFrame, pd.Series)):
        return pd.concat(seq, axis=0)
    elif ss and isinstance(seq[0], ss.spmatrix):
        return ss.vstack(seq, format='csr')
    else:
        raise TypeError('Data must be one of: numpy arrays, pandas dataframes, sparse matrices (from scipy). Got %s.'
% str(type(seq[0]))) + + +def _train_part(params, model_factory, list_of_parts, worker_addresses, return_model, local_listen_port=12400, + time_out=120, **kwargs): + network_params = build_network_params(worker_addresses, get_worker().address, local_listen_port, time_out) + params.update(network_params) + + # Concatenate many parts into one + parts = tuple(zip(*list_of_parts)) + data = _concat(parts[0]) + label = _concat(parts[1]) + weight = _concat(parts[2]) if len(parts) == 3 else None + + try: + model = model_factory(**params) + model.fit(data, label, sample_weight=weight, **kwargs) + finally: + _safe_call(_LIB.LGBM_NetworkFree()) + + return model if return_model else None + + +def _split_to_parts(data, is_matrix): + parts = data.to_delayed() + if isinstance(parts, np.ndarray): + assert (parts.shape[1] == 1) if is_matrix else (parts.ndim == 1 or parts.shape[1] == 1) + parts = parts.flatten().tolist() + return parts + + +def train(client, data, label, params, model_factory, weight=None, **kwargs): + """Inner train routine. + + Parameters + ---------- + client: dask.Client - client + X : dask array of shape = [n_samples, n_features] + Input feature matrix. + y : dask array of shape = [n_samples] + The target values (class labels in classification, real numbers in regression). + params : dict + model_factory : lightgbm.LGBMClassifier or lightgbm.LGBMRegressor class + sample_weight : array-like of shape = [n_samples] or None, optional (default=None) + Weights of training data. + """ + # Split arrays/dataframes into parts. Arrange parts into tuples to enforce co-locality + data_parts = _split_to_parts(data, is_matrix=True) + label_parts = _split_to_parts(label, is_matrix=False) + if weight is None: + parts = list(map(delayed, zip(data_parts, label_parts))) + else: + weight_parts = _split_to_parts(weight, is_matrix=False) + parts = list(map(delayed, zip(data_parts, label_parts, weight_parts))) + + # Start computation in the background + parts = client.compute(parts) + wait(parts) + + for part in parts: + if part.status == 'error': + return part # trigger error locally + + # Find locations of all parts and map them to particular Dask workers + key_to_part_dict = dict([(part.key, part) for part in parts]) + who_has = client.who_has(parts) + worker_map = defaultdict(list) + for key, workers in who_has.items(): + worker_map[first(workers)].append(key_to_part_dict[key]) + + master_worker = first(worker_map) + worker_ncores = client.ncores() + + if 'tree_learner' not in params or params['tree_learner'].lower() not in {'data', 'feature', 'voting'}: + logger.warning('Parameter tree_learner not set or set to incorrect value ' + '(%s), using "data" as default', params.get("tree_learner", None)) + params['tree_learner'] = 'data' + + # Tell each worker to train on the parts that it has locally + futures_classifiers = [client.submit(_train_part, + model_factory=model_factory, + params=assoc(params, 'num_threads', worker_ncores[worker]), + list_of_parts=list_of_parts, + worker_addresses=list(worker_map.keys()), + local_listen_port=params.get('local_listen_port', 12400), + time_out=params.get('time_out', 120), + return_model=(worker == master_worker), + **kwargs) + for worker, list_of_parts in worker_map.items()] + + results = client.gather(futures_classifiers) + results = [v for v in results if v] + return results[0] + + +def _predict_part(part, model, proba, **kwargs): + data = part.values if isinstance(part, pd.DataFrame) else part + + if data.shape[0] == 0: + result = np.array([]) + elif proba: + result 
= model.predict_proba(data, **kwargs) + else: + result = model.predict(data, **kwargs) + + if isinstance(part, pd.DataFrame): + if proba: + result = pd.DataFrame(result, index=part.index) + else: + result = pd.Series(result, index=part.index, name='predictions') + + return result + + +def predict(client, model, data, proba=False, dtype=np.float32, **kwargs): + """Inner predict routine. + + Parameters + ---------- + client: dask.Client - client + model : + data : dask array of shape = [n_samples, n_features] + Input feature matrix. + proba : bool + Should method return results of predict_proba (proba == True) or predict (proba == False) + dtype : np.dtype + Dtype of the output + kwargs : other parameters passed to predict or predict_proba method + """ + if isinstance(data, dd._Frame): + return data.map_partitions(_predict_part, model=model, proba=proba, **kwargs).values + elif isinstance(data, da.Array): + if proba: + kwargs['chunks'] = (data.chunks[0], (model.n_classes_,)) + else: + kwargs['drop_axis'] = 1 + return data.map_blocks(_predict_part, model=model, proba=proba, dtype=dtype, **kwargs) + else: + raise TypeError('Data must be either Dask array or dataframe. Got %s.' % str(type(data))) + + +class _LGBMModel: + + @staticmethod + def _copy_extra_params(source, dest): + params = source.get_params() + attributes = source.__dict__ + extra_param_names = set(attributes.keys()).difference(params.keys()) + for name in extra_param_names: + setattr(dest, name, attributes[name]) + + +class LGBMClassifier(_LGBMModel, lightgbm.LGBMClassifier): + """Distributed version of lightgbm.LGBMClassifier.""" + + def fit(self, X, y=None, sample_weight=None, client=None, **kwargs): + """Docstring is inherited from the LGBMModel.""" + if client is None: + client = default_client() + + model_factory = lightgbm.LGBMClassifier + params = self.get_params(True) + model = train(client, X, y, params, model_factory, sample_weight, **kwargs) + + self.set_params(**model.get_params()) + self._copy_extra_params(model, self) + + return self + fit.__doc__ = lightgbm.LGBMClassifier.fit.__doc__ + + def predict(self, X, client=None, **kwargs): + """Docstring is inherited from the lightgbm.LGBMClassifier.predict.""" + if client is None: + client = default_client() + return predict(client, self.to_local(), X, dtype=self.classes_.dtype, **kwargs) + predict.__doc__ = lightgbm.LGBMClassifier.predict.__doc__ + + def predict_proba(self, X, client=None, **kwargs): + """Docstring is inherited from the lightgbm.LGBMClassifier.predict_proba.""" + if client is None: + client = default_client() + return predict(client, self.to_local(), X, proba=True, **kwargs) + predict_proba.__doc__ = lightgbm.LGBMClassifier.predict_proba.__doc__ + + def to_local(self): + """Create regular version of lightgbm.LGBMClassifier from the distributed version. 
+ + Returns + ------- + model : lightgbm.LGBMClassifier + """ + model = lightgbm.LGBMClassifier(**self.get_params()) + self._copy_extra_params(self, model) + return model + + +class LGBMRegressor(_LGBMModel, lightgbm.LGBMRegressor): + """Docstring is inherited from the lightgbm.LGBMRegressor.""" + + def fit(self, X, y=None, sample_weight=None, client=None, **kwargs): + """Docstring is inherited from the lightgbm.LGBMRegressor.fit.""" + if client is None: + client = default_client() + + model_factory = lightgbm.LGBMRegressor + params = self.get_params(True) + model = train(client, X, y, params, model_factory, sample_weight, **kwargs) + + self.set_params(**model.get_params()) + self._copy_extra_params(model, self) + + return self + fit.__doc__ = lightgbm.LGBMRegressor.fit.__doc__ + + def predict(self, X, client=None, **kwargs): + """Docstring is inherited from the lightgbm.LGBMRegressor.predict.""" + if client is None: + client = default_client() + return predict(client, self.to_local(), X, **kwargs) + predict.__doc__ = lightgbm.LGBMRegressor.predict.__doc__ + + def to_local(self): + """Create regular version of lightgbm.LGBMRegressor from the distributed version. + + Returns + ------- + model : lightgbm.LGBMRegressor + """ + model = lightgbm.LGBMRegressor(**self.get_params()) + self._copy_extra_params(self, model) + return model diff --git a/python-package/setup.py b/python-package/setup.py index a2e8a0cf3560..cdcd92d65ecb 100644 --- a/python-package/setup.py +++ b/python-package/setup.py @@ -340,6 +340,12 @@ def run(self): 'scipy', 'scikit-learn!=0.22.0' ], + extras_requires={ + 'dask': [ + 'dask>=0.16.0', + 'distributed>=1.15.2' + ], + }, maintainer='Guolin Ke', maintainer_email='guolin.ke@microsoft.com', zip_safe=False, diff --git a/tests/python_package_test/test_dask.py b/tests/python_package_test/test_dask.py new file mode 100644 index 000000000000..50b18dd82bbf --- /dev/null +++ b/tests/python_package_test/test_dask.py @@ -0,0 +1,205 @@ +import dask.array as da +import dask.dataframe as dd +import lightgbm +import numpy as np +import pandas as pd +import pytest +import scipy.sparse +from dask.array.utils import assert_eq +from dask_ml.metrics import accuracy_score, r2_score + +from distributed.utils_test import client, cluster_fixture, loop, gen_cluster # noqa +from sklearn.datasets import make_blobs, make_regression +from sklearn.metrics import confusion_matrix + +import lightgbm.dask as dlgbm + +data_output = ['array', 'scipy_csr_matrix', 'dataframe'] +data_centers = [[[-4, -4], [4, 4]], [[-4, -4], [4, 4], [-4, 4]]] + + +@pytest.fixture() +def listen_port(): + listen_port.port += 10 + return listen_port.port + + +listen_port.port = 13000 + + +def _create_data(objective, n_samples=100, centers=2, output='array', chunk_size=50): + if objective == 'classification': + X, y = make_blobs(n_samples=n_samples, centers=centers, random_state=42) + elif objective == 'regression': + X, y = make_regression(n_samples=n_samples, random_state=42) + else: + raise ValueError(objective) + rnd = np.random.RandomState(42) + w = rnd.rand(X.shape[0]) * 0.01 + + if output == 'array': + dX = da.from_array(X, (chunk_size, X.shape[1])) + dy = da.from_array(y, chunk_size) + dw = da.from_array(w, chunk_size) + elif output == 'dataframe': + X_df = pd.DataFrame(X, columns=['feature_%d' % i for i in range(X.shape[1])]) + y_df = pd.Series(y, name='target') + dX = dd.from_pandas(X_df, chunksize=chunk_size) + dy = dd.from_pandas(y_df, chunksize=chunk_size) + dw = dd.from_array(w, chunksize=chunk_size) + elif output 
== 'scipy_csr_matrix': + dX = da.from_array(X, chunks=(chunk_size, X.shape[1])).map_blocks(scipy.sparse.csr_matrix) + dy = da.from_array(y, chunks=chunk_size) + dw = da.from_array(w, chunk_size) + + return X, y, w, dX, dy, dw + + +@pytest.mark.parametrize('output', data_output) +@pytest.mark.parametrize('centers', data_centers) +def test_classifier(output, centers, client, listen_port): # noqa + X, y, w, dX, dy, dw = _create_data('classification', output=output, centers=centers) + + a = dlgbm.LGBMClassifier(time_out=5, local_listen_port=listen_port) + a = a.fit(dX, dy, sample_weight=dw, client=client) + p1 = a.predict(dX, client=client) + s1 = accuracy_score(dy, p1) + p1 = p1.compute() + + b = lightgbm.LGBMClassifier() + b.fit(X, y, sample_weight=w) + p2 = b.predict(X) + s2 = b.score(X, y) + print(confusion_matrix(y, p1)) + print(confusion_matrix(y, p2)) + + assert_eq(s1, s2) + print(s1) + + assert_eq(p1, p2) + assert_eq(y, p1) + assert_eq(y, p2) + + +@pytest.mark.parametrize('output', data_output) +@pytest.mark.parametrize('centers', data_centers) +def test_classifier_proba(output, centers, client, listen_port): # noqa + X, y, w, dX, dy, dw = _create_data('classification', output=output, centers=centers) + + a = dlgbm.LGBMClassifier(time_out=5, local_listen_port=listen_port) + a = a.fit(dX, dy, sample_weight=dw, client=client) + p1 = a.predict_proba(dX, client=client) + p1 = p1.compute() + + b = lightgbm.LGBMClassifier() + b.fit(X, y, sample_weight=w) + p2 = b.predict_proba(X) + + assert_eq(p1, p2, atol=0.3) + + +def test_classifier_local_predict(client, listen_port): # noqa + X, y, w, dX, dy, dw = _create_data('classification', output='array') + + a = dlgbm.LGBMClassifier(time_out=5, local_listen_port=listen_port) + a = a.fit(dX, dy, sample_weight=dw, client=client) + p1 = a.to_local().predict(dX) + + b = lightgbm.LGBMClassifier() + b.fit(X, y, sample_weight=w) + p2 = b.predict(X) + + assert_eq(p1, p2) + assert_eq(y, p1) + assert_eq(y, p2) + + +@pytest.mark.parametrize('output', data_output) +def test_regressor(output, client, listen_port): # noqa + X, y, w, dX, dy, dw = _create_data('regression', output=output) + + a = dlgbm.LGBMRegressor(time_out=5, local_listen_port=listen_port, seed=42) + a = a.fit(dX, dy, client=client, sample_weight=dw) + p1 = a.predict(dX, client=client) + if output != 'dataframe': + s1 = r2_score(dy, p1) + p1 = p1.compute() + + b = lightgbm.LGBMRegressor(seed=42) + b.fit(X, y, sample_weight=w) + s2 = b.score(X, y) + p2 = b.predict(X) + + # Scores should be the same + if output != 'dataframe': + assert_eq(s1, s2, atol=.01) + + # Predictions should be roughly the same + assert_eq(y, p1, rtol=1., atol=50.) + assert_eq(y, p2, rtol=1., atol=50.) 
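
# Note on the quantile tests below: a model trained with objective='quantile'
# at level alpha predicts the conditional alpha-quantile, so the true target
# should fall below the prediction on roughly an alpha fraction of samples;
# np.count_nonzero(y < pred) / y.shape[0] estimates exactly that fraction.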
+ + +@pytest.mark.parametrize('output', data_output) +@pytest.mark.parametrize('alpha', [.1, .5, .9]) +def test_regressor_quantile(output, client, listen_port, alpha): # noqa + X, y, w, dX, dy, dw = _create_data('regression', output=output) + + a = dlgbm.LGBMRegressor(local_listen_port=listen_port, seed=42, objective='quantile', alpha=alpha) + a = a.fit(dX, dy, client=client, sample_weight=dw) + p1 = a.predict(dX, client=client).compute() + q1 = np.count_nonzero(y < p1) / y.shape[0] + + b = lightgbm.LGBMRegressor(seed=42, objective='quantile', alpha=alpha) + b.fit(X, y, sample_weight=w) + p2 = b.predict(X) + q2 = np.count_nonzero(y < p2) / y.shape[0] + + # Quantiles should be right + np.isclose(q1, alpha, atol=.1) + np.isclose(q2, alpha, atol=.1) + + +def test_regressor_local_predict(client, listen_port): # noqa + X, y, w, dX, dy, dw = _create_data('regression', output='array') + + a = dlgbm.LGBMRegressor(local_listen_port=listen_port, seed=42) + a = a.fit(dX, dy, sample_weight=dw, client=client) + p1 = a.predict(dX) + p2 = a.to_local().predict(X) + s1 = r2_score(dy, p1) + p1 = p1.compute() + s2 = a.to_local().score(X, y) + print(s1) + + # Predictions and scores should be the same + assert_eq(p1, p2) + np.isclose(s1, s2) + + +def test_build_network_params(): + workers_ips = [ + 'tcp://192.168.0.1:34545', + 'tcp://192.168.0.2:34346', + 'tcp://192.168.0.3:34347' + ] + + params = dlgbm.build_network_params(workers_ips, 'tcp://192.168.0.2:34346', 12400, 120) + exp_params = { + 'machines': '192.168.0.1:12400,192.168.0.2:12401,192.168.0.3:12402', + 'local_listen_port': 12401, + 'num_machines': len(workers_ips), + 'time_out': 120 + } + assert exp_params == params + + +@gen_cluster(client=True, timeout=None) +def test_errors(c, s, a, b): + def f(part): + raise Exception('foo') + + df = dd.demo.make_timeseries() + df = df.map_partitions(f, meta=df._meta) + with pytest.raises(Exception) as info: + yield dlgbm.train(c, df, df.x, params={}, model_factory=lightgbm.LGBMClassifier) + assert 'foo' in str(info.value) From 10c594cfe194b6add86840ac0daaa20b16cf9833 Mon Sep 17 00:00:00 2001 From: Jan Stiborek Date: Sat, 7 Nov 2020 22:38:54 +0100 Subject: [PATCH 02/19] relaxed tests --- tests/python_package_test/test_dask.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python_package_test/test_dask.py b/tests/python_package_test/test_dask.py index 50b18dd82bbf..289746dfb0ad 100644 --- a/tests/python_package_test/test_dask.py +++ b/tests/python_package_test/test_dask.py @@ -135,7 +135,7 @@ def test_regressor(output, client, listen_port): # noqa assert_eq(s1, s2, atol=.01) # Predictions should be roughly the same - assert_eq(y, p1, rtol=1., atol=50.) + assert_eq(y, p1, rtol=1., atol=100.) assert_eq(y, p2, rtol=1., atol=50.) 
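
A note on the network setup exercised by the tests above: the port-assignment
scheme introduced by PATCH 01 in build_network_params (and asserted exactly in
test_build_network_params) gives the i-th worker address the port
local_listen_port + i and publishes the full cluster membership through the
'machines' parameter. Below is a minimal self-contained sketch of that logic,
for illustration only; sketch_network_params is a hypothetical name, not part
of the package:

    from urllib.parse import urlparse

    def sketch_network_params(worker_addresses, local_worker_ip, local_listen_port, time_out):
        # Assign local_listen_port + i to the i-th worker address, in enumeration order.
        addr_port_map = {addr: local_listen_port + i for i, addr in enumerate(worker_addresses)}
        return {
            # Comma-separated host:port pairs covering every machine in the cluster.
            'machines': ','.join('%s:%d' % (urlparse(addr).hostname, port)
                                 for addr, port in addr_port_map.items()),
            # The port this particular worker will listen on.
            'local_listen_port': addr_port_map[local_worker_ip],
            'time_out': time_out,
            'num_machines': len(addr_port_map),
        }

    workers = ['tcp://192.168.0.1:34545', 'tcp://192.168.0.2:34346', 'tcp://192.168.0.3:34347']
    print(sketch_network_params(workers, 'tcp://192.168.0.2:34346', 12400, 120))
    # {'machines': '192.168.0.1:12400,192.168.0.2:12401,192.168.0.3:12402',
    #  'local_listen_port': 12401, 'time_out': 120, 'num_machines': 3}
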
From 094c4cd6503b862056d55b662e37c0c0b4ff5d54 Mon Sep 17 00:00:00 2001 From: Jan Stiborek Date: Sat, 7 Nov 2020 22:58:14 +0100 Subject: [PATCH 03/19] tests skipped in case that MPI is used --- tests/python_package_test/test_dask.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/python_package_test/test_dask.py b/tests/python_package_test/test_dask.py index 289746dfb0ad..b2e5e1dfa4f3 100644 --- a/tests/python_package_test/test_dask.py +++ b/tests/python_package_test/test_dask.py @@ -1,3 +1,6 @@ +import os +import sys + import dask.array as da import dask.dataframe as dd import lightgbm @@ -114,6 +117,7 @@ def test_classifier_local_predict(client, listen_port): # noqa assert_eq(y, p2) +@pytest.mark.skipif(os.getenv("TASK", "") == "mpi", reason="Fails to run with MPI interface") @pytest.mark.parametrize('output', data_output) def test_regressor(output, client, listen_port): # noqa X, y, w, dX, dy, dw = _create_data('regression', output=output) From 07b7ef412b4a289a6ad8aef2b4dfca1deef343e5 Mon Sep 17 00:00:00 2001 From: Jan Stiborek Date: Fri, 13 Nov 2020 21:29:41 +0100 Subject: [PATCH 04/19] fixed python 2.7 import + tests disabled on windows --- python-package/lightgbm/{dask.py => dask_distributed.py} | 6 ++++-- python-package/setup.py | 7 ++++--- .../{test_dask.py => test_dask_distributed.py} | 9 +++++++-- 3 files changed, 15 insertions(+), 7 deletions(-) rename python-package/lightgbm/{dask.py => dask_distributed.py} (98%) rename tests/python_package_test/{test_dask.py => test_dask_distributed.py} (95%) diff --git a/python-package/lightgbm/dask.py b/python-package/lightgbm/dask_distributed.py similarity index 98% rename from python-package/lightgbm/dask.py rename to python-package/lightgbm/dask_distributed.py index e70983040eca..21580b7999fa 100644 --- a/python-package/lightgbm/dask.py +++ b/python-package/lightgbm/dask_distributed.py @@ -11,8 +11,10 @@ except ImportError: from urlparse import urlparse -import dask.array as da -import dask.dataframe as dd +from dask import array as da +from dask import dataframe as dd +# import dask.array as da +# import dask.dataframe as dd import lightgbm import numpy as np import pandas as pd diff --git a/python-package/setup.py b/python-package/setup.py index cdcd92d65ecb..869fdb435761 100644 --- a/python-package/setup.py +++ b/python-package/setup.py @@ -340,10 +340,11 @@ def run(self): 'scipy', 'scikit-learn!=0.22.0' ], - extras_requires={ + extras_require={ 'dask': [ - 'dask>=0.16.0', - 'distributed>=1.15.2' + 'dask[array]>=1.0.0', + 'dask[dataframe]>=1.0.0' + 'dask[distributed]>=1.0.0' ], }, maintainer='Guolin Ke', diff --git a/tests/python_package_test/test_dask.py b/tests/python_package_test/test_dask_distributed.py similarity index 95% rename from tests/python_package_test/test_dask.py rename to tests/python_package_test/test_dask_distributed.py index b2e5e1dfa4f3..035d1eb3d59a 100644 --- a/tests/python_package_test/test_dask.py +++ b/tests/python_package_test/test_dask_distributed.py @@ -9,17 +9,22 @@ import pytest import scipy.sparse from dask.array.utils import assert_eq -from dask_ml.metrics import accuracy_score, r2_score +try: + from dask_ml.metrics import accuracy_score, r2_score +except ImportError: + from sklearn.metrics import accuracy_score, r2_score from distributed.utils_test import client, cluster_fixture, loop, gen_cluster # noqa from sklearn.datasets import make_blobs, make_regression from sklearn.metrics import confusion_matrix -import lightgbm.dask as dlgbm +import lightgbm.dask_distributed as dlgbm data_output = 
['array', 'scipy_csr_matrix', 'dataframe'] data_centers = [[[-4, -4], [4, 4]], [[-4, -4], [4, 4], [-4, 4]]] +pytestmark = pytest.mark.skipif(sys.platform == "win32", reason="Windows is currently not supported") + @pytest.fixture() def listen_port(): From 1a3954e770f83e1be183029d719a52af87a2b5a1 Mon Sep 17 00:00:00 2001 From: Jan Stiborek Date: Sat, 14 Nov 2020 12:08:37 +0100 Subject: [PATCH 05/19] python < 3.6 is not supported in tests --- tests/python_package_test/test_dask_distributed.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/python_package_test/test_dask_distributed.py b/tests/python_package_test/test_dask_distributed.py index 035d1eb3d59a..a33a78c342a1 100644 --- a/tests/python_package_test/test_dask_distributed.py +++ b/tests/python_package_test/test_dask_distributed.py @@ -23,7 +23,10 @@ data_output = ['array', 'scipy_csr_matrix', 'dataframe'] data_centers = [[[-4, -4], [4, 4]], [[-4, -4], [4, 4], [-4, 4]]] -pytestmark = pytest.mark.skipif(sys.platform == "win32", reason="Windows is currently not supported") +pytestmark = [ + pytest.mark.skipif(sys.platform == "win32", reason="Windows is currently not supported"), + pytest.mark.skipif(sys.version_info < (3, 6), reason="Only python 3.6 is supported") +] @pytest.fixture() From 38e4b5178b9bcedc5db4b5477b75f6d0fa76ef0b Mon Sep 17 00:00:00 2001 From: Jan Stiborek Date: Sat, 14 Nov 2020 19:18:37 +0100 Subject: [PATCH 06/19] tests enabled only for linux --- tests/python_package_test/test_dask_distributed.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/python_package_test/test_dask_distributed.py b/tests/python_package_test/test_dask_distributed.py index a33a78c342a1..92f1d29bf3f6 100644 --- a/tests/python_package_test/test_dask_distributed.py +++ b/tests/python_package_test/test_dask_distributed.py @@ -24,8 +24,8 @@ data_centers = [[[-4, -4], [4, 4]], [[-4, -4], [4, 4], [-4, 4]]] pytestmark = [ - pytest.mark.skipif(sys.platform == "win32", reason="Windows is currently not supported"), - pytest.mark.skipif(sys.version_info < (3, 6), reason="Only python 3.6 is supported") + pytest.mark.skipif(sys.platform != "linux", reason="Only linux is currently supported"), + pytest.mark.skipif(sys.version_info < (3, 6), reason="Only python>=3.6 is supported") ] From 4ca3b22ccecaa0c7b0382d3851b242a63756d7e2 Mon Sep 17 00:00:00 2001 From: Jan Stiborek Date: Sat, 14 Nov 2020 20:53:44 +0100 Subject: [PATCH 07/19] tests disabled for mpi interface --- tests/python_package_test/test_dask_distributed.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/python_package_test/test_dask_distributed.py b/tests/python_package_test/test_dask_distributed.py index 92f1d29bf3f6..342c17b6e32a 100644 --- a/tests/python_package_test/test_dask_distributed.py +++ b/tests/python_package_test/test_dask_distributed.py @@ -25,7 +25,8 @@ pytestmark = [ pytest.mark.skipif(sys.platform != "linux", reason="Only linux is currently supported"), - pytest.mark.skipif(sys.version_info < (3, 6), reason="Only python>=3.6 is supported") + pytest.mark.skipif(sys.version_info < (3, 6), reason="Only python>=3.6 is supported"), + pytest.mark.skipif(os.getenv("TASK", "") == "mpi", reason="Fails to run with MPI interface") ] @@ -125,7 +126,6 @@ def test_classifier_local_predict(client, listen_port): # noqa assert_eq(y, p2) -@pytest.mark.skipif(os.getenv("TASK", "") == "mpi", reason="Fails to run with MPI interface") @pytest.mark.parametrize('output', data_output) def test_regressor(output, client, 
listen_port): # noqa X, y, w, dX, dy, dw = _create_data('regression', output=output) From c22310c9359c7fafa1b7aed383e2ad9eefaf6731 Mon Sep 17 00:00:00 2001 From: Jan Stiborek Date: Sat, 14 Nov 2020 21:03:21 +0100 Subject: [PATCH 08/19] dask version pinned to >= 2.0 --- python-package/setup.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python-package/setup.py b/python-package/setup.py index 869fdb435761..829fab14ca6f 100644 --- a/python-package/setup.py +++ b/python-package/setup.py @@ -342,9 +342,9 @@ def run(self): ], extras_require={ 'dask': [ - 'dask[array]>=1.0.0', - 'dask[dataframe]>=1.0.0' - 'dask[distributed]>=1.0.0' + 'dask[array]>=2.0.0', + 'dask[dataframe]>=2.0.0' + 'dask[distributed]>=2.0.0' ], }, maintainer='Guolin Ke', From 9fd7f594db92072973dd3bcdddd26cd0b1352252 Mon Sep 17 00:00:00 2001 From: Jan Stiborek Date: Sun, 15 Nov 2020 12:02:59 +0100 Subject: [PATCH 09/19] added @jameslamb as code owner --- .github/CODEOWNERS | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 7164a84e8ca5..a492c6532ff6 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -30,6 +30,10 @@ R-package/ @Laurae2 @jameslamb # Python code python-package/ @StrikerRUS @chivee @wxchan @henry0312 +# Dask integration +python-package/lightgbm/dask_distributed.py @jameslamb +tests/python_package_test/test_dask_distributed.py @jameslamb + # helpers helpers/ @StrikerRUS @guolinke From 5bf6fcf193911d2283aa15ab2e5f3ea688d081c4 Mon Sep 17 00:00:00 2001 From: Jan Stiborek Date: Sun, 15 Nov 2020 12:15:59 +0100 Subject: [PATCH 10/19] added missing pandas dependency --- python-package/setup.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python-package/setup.py b/python-package/setup.py index 829fab14ca6f..4f41626c7d17 100644 --- a/python-package/setup.py +++ b/python-package/setup.py @@ -344,7 +344,8 @@ def run(self): 'dask': [ 'dask[array]>=2.0.0', 'dask[dataframe]>=2.0.0' - 'dask[distributed]>=2.0.0' + 'dask[distributed]>=2.0.0', + 'pandas' ], }, maintainer='Guolin Ke', From ef072111ac0c02221bc1071cf456306bf99536a9 Mon Sep 17 00:00:00 2001 From: Jan Stiborek Date: Sun, 15 Nov 2020 12:17:21 +0100 Subject: [PATCH 11/19] code refactoring, removed code duplication - lightgbm.dask.LGBMClassifier.fit is the same as lightgbm.dask.LGBMRegressor.fit --- .../lightgbm/{dask_distributed.py => dask.py} | 138 ++++++++---------- ...{test_dask_distributed.py => test_dask.py} | 45 +++--- 2 files changed, 79 insertions(+), 104 deletions(-) rename python-package/lightgbm/{dask_distributed.py => dask.py} (74%) rename tests/python_package_test/{test_dask_distributed.py => test_dask.py} (81%) diff --git a/python-package/lightgbm/dask_distributed.py b/python-package/lightgbm/dask.py similarity index 74% rename from python-package/lightgbm/dask_distributed.py rename to python-package/lightgbm/dask.py index 21580b7999fa..c8f14c7f7c0f 100644 --- a/python-package/lightgbm/dask_distributed.py +++ b/python-package/lightgbm/dask.py @@ -5,23 +5,20 @@ """ import logging from collections import defaultdict +from urllib.parse import urlparse -try: - from urllib.parse import urlparse -except ImportError: - from urlparse import urlparse - -from dask import array as da -from dask import dataframe as dd -# import dask.array as da -# import dask.dataframe as dd -import lightgbm import numpy as np import pandas as pd +from dask import array as da +from dask import dataframe as dd from dask import delayed -from dask.distributed import wait, 
default_client, get_worker -from lightgbm.basic import _safe_call, _LIB -from toolz import first, assoc +from dask.distributed import default_client, get_worker, wait +from toolz import assoc, first + +import lightgbm +from .basic import _LIB, _safe_call +from .sklearn import LGBMClassifier as LocalLGBMClassifier, LGBMRegressor as LocalLGBMRegressor +from .compat import _LGBMModelBase try: import scipy.sparse as ss @@ -36,7 +33,7 @@ def _parse_host_port(address): return parsed.hostname, parsed.port -def build_network_params(worker_addresses, local_worker_ip, local_listen_port, time_out): +def _build_network_params(worker_addresses, local_worker_ip, local_listen_port, time_out): """Build network parameters suiltable for LightGBM C backend. Parameters @@ -73,7 +70,7 @@ def _concat(seq): def _train_part(params, model_factory, list_of_parts, worker_addresses, return_model, local_listen_port=12400, time_out=120, **kwargs): - network_params = build_network_params(worker_addresses, get_worker().address, local_listen_port, time_out) + network_params = _build_network_params(worker_addresses, get_worker().address, local_listen_port, time_out) params.update(network_params) # Concatenate many parts into one @@ -99,7 +96,7 @@ def _split_to_parts(data, is_matrix): return parts -def train(client, data, label, params, model_factory, weight=None, **kwargs): +def _train(client, data, label, params, model_factory, weight=None, **kwargs): """Inner train routine. Parameters @@ -182,7 +179,7 @@ def _predict_part(part, model, proba, **kwargs): return result -def predict(client, model, data, proba=False, dtype=np.float32, **kwargs): +def _predict(model, data, proba=False, dtype=np.float32, **kwargs): """Inner predict routine. Parameters @@ -209,93 +206,74 @@ def predict(client, model, data, proba=False, dtype=np.float32, **kwargs): raise TypeError('Data must be either Dask array or dataframe. Got %s.' 
% str(type(data))) -class _LGBMModel: - - @staticmethod - def _copy_extra_params(source, dest): - params = source.get_params() - attributes = source.__dict__ - extra_param_names = set(attributes.keys()).difference(params.keys()) - for name in extra_param_names: - setattr(dest, name, attributes[name]) - - -class LGBMClassifier(_LGBMModel, lightgbm.LGBMClassifier): - """Distributed version of lightgbm.LGBMClassifier.""" +class _LGBMModel(_LGBMModelBase): - def fit(self, X, y=None, sample_weight=None, client=None, **kwargs): - """Docstring is inherited from the LGBMModel.""" + def __init__(self, model_factory, client=None) -> None: if client is None: client = default_client() + self.client = client + self.model_factory = model_factory - model_factory = lightgbm.LGBMClassifier + def fit(self, X, y=None, sample_weight=None, **kwargs): + """Docstring is inherited from the appropriate model_factory.""" params = self.get_params(True) - model = train(client, X, y, params, model_factory, sample_weight, **kwargs) + model = _train(self.client, X, y, params, self.model_factory, sample_weight, **kwargs) self.set_params(**model.get_params()) self._copy_extra_params(model, self) return self - fit.__doc__ = lightgbm.LGBMClassifier.fit.__doc__ - - def predict(self, X, client=None, **kwargs): - """Docstring is inherited from the lightgbm.LGBMClassifier.predict.""" - if client is None: - client = default_client() - return predict(client, self.to_local(), X, dtype=self.classes_.dtype, **kwargs) - predict.__doc__ = lightgbm.LGBMClassifier.predict.__doc__ - - def predict_proba(self, X, client=None, **kwargs): - """Docstring is inherited from the lightgbm.LGBMClassifier.predict_proba.""" - if client is None: - client = default_client() - return predict(client, self.to_local(), X, proba=True, **kwargs) - predict_proba.__doc__ = lightgbm.LGBMClassifier.predict_proba.__doc__ def to_local(self): - """Create regular version of lightgbm.LGBMClassifier from the distributed version. + """Create regular version of lightgbm.LGBMRegressor from the distributed version. 
Returns ------- - model : lightgbm.LGBMClassifier + model : lightgbm.LGBMRegressor """ - model = lightgbm.LGBMClassifier(**self.get_params()) + model = self.model_factory(**self.get_params()) self._copy_extra_params(self, model) return model + @staticmethod + def _copy_extra_params(source, dest): + params = source.get_params() + attributes = source.__dict__ + extra_param_names = set(attributes.keys()).difference(params.keys()) + for name in extra_param_names: + setattr(dest, name, attributes[name]) + -class LGBMRegressor(_LGBMModel, lightgbm.LGBMRegressor): - """Docstring is inherited from the lightgbm.LGBMRegressor.""" +class LGBMClassifier(_LGBMModel, LocalLGBMClassifier): + """Distributed version of lightgbm.LGBMClassifier.""" - def fit(self, X, y=None, sample_weight=None, client=None, **kwargs): - """Docstring is inherited from the lightgbm.LGBMRegressor.fit.""" - if client is None: - client = default_client() + def __init__(self, client=None, **kwargs) -> None: + super().__init__(LocalLGBMClassifier, client) + super(_LGBMModel, self).__init__(**kwargs) - model_factory = lightgbm.LGBMRegressor - params = self.get_params(True) - model = train(client, X, y, params, model_factory, sample_weight, **kwargs) + _LGBMModel.fit.__doc__ = LocalLGBMClassifier.fit.__doc__ - self.set_params(**model.get_params()) - self._copy_extra_params(model, self) + def predict(self, X, **kwargs): + """Docstring is inherited from the lightgbm.LGBMClassifier.predict.""" + return _predict(self.to_local(), X, dtype=self.classes_.dtype, **kwargs) + predict.__doc__ = LocalLGBMClassifier.predict.__doc__ - return self - fit.__doc__ = lightgbm.LGBMRegressor.fit.__doc__ + def predict_proba(self, X, **kwargs): + """Docstring is inherited from the lightgbm.LGBMClassifier.predict_proba.""" + return _predict(self.to_local(), X, proba=True, **kwargs) + predict_proba.__doc__ = LocalLGBMClassifier.predict_proba.__doc__ - def predict(self, X, client=None, **kwargs): - """Docstring is inherited from the lightgbm.LGBMRegressor.predict.""" - if client is None: - client = default_client() - return predict(client, self.to_local(), X, **kwargs) - predict.__doc__ = lightgbm.LGBMRegressor.predict.__doc__ - def to_local(self): - """Create regular version of lightgbm.LGBMRegressor from the distributed version. 
+class LGBMRegressor(_LGBMModel, LocalLGBMRegressor): + """Docstring is inherited from the lightgbm.LGBMRegressor.""" - Returns - ------- - model : lightgbm.LGBMRegressor - """ - model = lightgbm.LGBMRegressor(**self.get_params()) - self._copy_extra_params(self, model) - return model + def __init__(self, client=None, **kwargs) -> None: + super().__init__(LocalLGBMRegressor, client) + super(LocalLGBMRegressor, self).__init__(**kwargs) + + _LGBMModel.fit.__doc__ = LocalLGBMRegressor.fit.__doc__ + + def predict(self, X, **kwargs): + """Docstring is inherited from the lightgbm.LGBMRegressor.predict.""" + return _predict(self.to_local(), X, **kwargs) + predict.__doc__ = LocalLGBMRegressor.predict.__doc__ diff --git a/tests/python_package_test/test_dask_distributed.py b/tests/python_package_test/test_dask.py similarity index 81% rename from tests/python_package_test/test_dask_distributed.py rename to tests/python_package_test/test_dask.py index 342c17b6e32a..d64bc6924049 100644 --- a/tests/python_package_test/test_dask_distributed.py +++ b/tests/python_package_test/test_dask.py @@ -16,9 +16,8 @@ from distributed.utils_test import client, cluster_fixture, loop, gen_cluster # noqa from sklearn.datasets import make_blobs, make_regression -from sklearn.metrics import confusion_matrix -import lightgbm.dask_distributed as dlgbm +import lightgbm.dask as dlgbm data_output = ['array', 'scipy_csr_matrix', 'dataframe'] data_centers = [[[-4, -4], [4, 4]], [[-4, -4], [4, 4], [-4, 4]]] @@ -47,7 +46,7 @@ def _create_data(objective, n_samples=100, centers=2, output='array', chunk_size else: raise ValueError(objective) rnd = np.random.RandomState(42) - w = rnd.rand(X.shape[0]) * 0.01 + w = rnd.random(X.shape[0]) * 0.01 if output == 'array': dX = da.from_array(X, (chunk_size, X.shape[1])) @@ -63,6 +62,8 @@ def _create_data(objective, n_samples=100, centers=2, output='array', chunk_size dX = da.from_array(X, chunks=(chunk_size, X.shape[1])).map_blocks(scipy.sparse.csr_matrix) dy = da.from_array(y, chunks=chunk_size) dw = da.from_array(w, chunk_size) + else: + raise ValueError("Unknown output type %s" % output) return X, y, w, dX, dy, dw @@ -72,9 +73,9 @@ def _create_data(objective, n_samples=100, centers=2, output='array', chunk_size def test_classifier(output, centers, client, listen_port): # noqa X, y, w, dX, dy, dw = _create_data('classification', output=output, centers=centers) - a = dlgbm.LGBMClassifier(time_out=5, local_listen_port=listen_port) - a = a.fit(dX, dy, sample_weight=dw, client=client) - p1 = a.predict(dX, client=client) + a = dlgbm.LGBMClassifier(client=client, time_out=5, local_listen_port=listen_port) + a = a.fit(dX, dy, sample_weight=dw) + p1 = a.predict(dX) s1 = accuracy_score(dy, p1) p1 = p1.compute() @@ -82,11 +83,8 @@ def test_classifier(output, centers, client, listen_port): # noqa b.fit(X, y, sample_weight=w) p2 = b.predict(X) s2 = b.score(X, y) - print(confusion_matrix(y, p1)) - print(confusion_matrix(y, p2)) assert_eq(s1, s2) - print(s1) assert_eq(p1, p2) assert_eq(y, p1) @@ -98,9 +96,9 @@ def test_classifier(output, centers, client, listen_port): # noqa def test_classifier_proba(output, centers, client, listen_port): # noqa X, y, w, dX, dy, dw = _create_data('classification', output=output, centers=centers) - a = dlgbm.LGBMClassifier(time_out=5, local_listen_port=listen_port) - a = a.fit(dX, dy, sample_weight=dw, client=client) - p1 = a.predict_proba(dX, client=client) + a = dlgbm.LGBMClassifier(client=client, time_out=5, local_listen_port=listen_port) + a = a.fit(dX, dy, 
sample_weight=dw) + p1 = a.predict_proba(dX) p1 = p1.compute() b = lightgbm.LGBMClassifier() @@ -113,8 +111,8 @@ def test_classifier_proba(output, centers, client, listen_port): # noqa def test_classifier_local_predict(client, listen_port): # noqa X, y, w, dX, dy, dw = _create_data('classification', output='array') - a = dlgbm.LGBMClassifier(time_out=5, local_listen_port=listen_port) - a = a.fit(dX, dy, sample_weight=dw, client=client) + a = dlgbm.LGBMClassifier(client=client, time_out=5, local_listen_port=listen_port) + a = a.fit(dX, dy, sample_weight=dw) p1 = a.to_local().predict(dX) b = lightgbm.LGBMClassifier() @@ -130,9 +128,9 @@ def test_classifier_local_predict(client, listen_port): # noqa def test_regressor(output, client, listen_port): # noqa X, y, w, dX, dy, dw = _create_data('regression', output=output) - a = dlgbm.LGBMRegressor(time_out=5, local_listen_port=listen_port, seed=42) - a = a.fit(dX, dy, client=client, sample_weight=dw) - p1 = a.predict(dX, client=client) + a = dlgbm.LGBMRegressor(client=client, time_out=5, local_listen_port=listen_port, seed=42) + a = a.fit(dX, dy, sample_weight=dw) + p1 = a.predict(dX) if output != 'dataframe': s1 = r2_score(dy, p1) p1 = p1.compute() @@ -156,9 +154,9 @@ def test_regressor(output, client, listen_port): # noqa def test_regressor_quantile(output, client, listen_port, alpha): # noqa X, y, w, dX, dy, dw = _create_data('regression', output=output) - a = dlgbm.LGBMRegressor(local_listen_port=listen_port, seed=42, objective='quantile', alpha=alpha) + a = dlgbm.LGBMRegressor(client=client, local_listen_port=listen_port, seed=42, objective='quantile', alpha=alpha) a = a.fit(dX, dy, client=client, sample_weight=dw) - p1 = a.predict(dX, client=client).compute() + p1 = a.predict(dX).compute() q1 = np.count_nonzero(y < p1) / y.shape[0] b = lightgbm.LGBMRegressor(seed=42, objective='quantile', alpha=alpha) @@ -174,14 +172,13 @@ def test_regressor_quantile(output, client, listen_port, alpha): # noqa def test_regressor_local_predict(client, listen_port): # noqa X, y, w, dX, dy, dw = _create_data('regression', output='array') - a = dlgbm.LGBMRegressor(local_listen_port=listen_port, seed=42) - a = a.fit(dX, dy, sample_weight=dw, client=client) + a = dlgbm.LGBMRegressor(client=client, local_listen_port=listen_port, seed=42) + a = a.fit(dX, dy, sample_weight=dw) p1 = a.predict(dX) p2 = a.to_local().predict(X) s1 = r2_score(dy, p1) p1 = p1.compute() s2 = a.to_local().score(X, y) - print(s1) # Predictions and scores should be the same assert_eq(p1, p2) @@ -195,7 +192,7 @@ def test_build_network_params(): 'tcp://192.168.0.3:34347' ] - params = dlgbm.build_network_params(workers_ips, 'tcp://192.168.0.2:34346', 12400, 120) + params = dlgbm._build_network_params(workers_ips, 'tcp://192.168.0.2:34346', 12400, 120) exp_params = { 'machines': '192.168.0.1:12400,192.168.0.2:12401,192.168.0.3:12402', 'local_listen_port': 12401, @@ -213,5 +210,5 @@ def f(part): df = dd.demo.make_timeseries() df = df.map_partitions(f, meta=df._meta) with pytest.raises(Exception) as info: - yield dlgbm.train(c, df, df.x, params={}, model_factory=lightgbm.LGBMClassifier) + yield dlgbm._train(c, df, df.x, params={}, model_factory=lightgbm.LGBMClassifier) assert 'foo' in str(info.value) From d580dffb37a4427740ccb5232a9e36fadacc1728 Mon Sep 17 00:00:00 2001 From: Jan Stiborek Date: Sun, 15 Nov 2020 16:51:06 +0100 Subject: [PATCH 12/19] fixed refactoring --- python-package/lightgbm/dask.py | 87 +++++++++++++++----------- python-package/setup.py | 3 +- 
tests/python_package_test/test_dask.py | 32 +++++----- 3 files changed, 65 insertions(+), 57 deletions(-) diff --git a/python-package/lightgbm/dask.py b/python-package/lightgbm/dask.py index c8f14c7f7c0f..d0f793ec6203 100644 --- a/python-package/lightgbm/dask.py +++ b/python-package/lightgbm/dask.py @@ -18,7 +18,6 @@ import lightgbm from .basic import _LIB, _safe_call from .sklearn import LGBMClassifier as LocalLGBMClassifier, LGBMRegressor as LocalLGBMRegressor -from .compat import _LGBMModelBase try: import scipy.sparse as ss @@ -184,7 +183,6 @@ def _predict(model, data, proba=False, dtype=np.float32, **kwargs): Parameters ---------- - client: dask.Client - client model : data : dask array of shape = [n_samples, n_features] Input feature matrix. @@ -206,34 +204,7 @@ def _predict(model, data, proba=False, dtype=np.float32, **kwargs): raise TypeError('Data must be either Dask array or dataframe. Got %s.' % str(type(data))) -class _LGBMModel(_LGBMModelBase): - - def __init__(self, model_factory, client=None) -> None: - if client is None: - client = default_client() - self.client = client - self.model_factory = model_factory - - def fit(self, X, y=None, sample_weight=None, **kwargs): - """Docstring is inherited from the appropriate model_factory.""" - params = self.get_params(True) - model = _train(self.client, X, y, params, self.model_factory, sample_weight, **kwargs) - - self.set_params(**model.get_params()) - self._copy_extra_params(model, self) - - return self - - def to_local(self): - """Create regular version of lightgbm.LGBMRegressor from the distributed version. - - Returns - ------- - model : lightgbm.LGBMRegressor - """ - model = self.model_factory(**self.get_params()) - self._copy_extra_params(self, model) - return model +class _LGBMModel: @staticmethod def _copy_extra_params(source, dest): @@ -247,11 +218,20 @@ def _copy_extra_params(source, dest): class LGBMClassifier(_LGBMModel, LocalLGBMClassifier): """Distributed version of lightgbm.LGBMClassifier.""" - def __init__(self, client=None, **kwargs) -> None: - super().__init__(LocalLGBMClassifier, client) - super(_LGBMModel, self).__init__(**kwargs) + def fit(self, X, y=None, sample_weight=None, client=None, **kwargs): + """Docstring is inherited from the LGBMModel.""" + if client is None: + client = default_client() - _LGBMModel.fit.__doc__ = LocalLGBMClassifier.fit.__doc__ + model_factory = LocalLGBMClassifier + params = self.get_params(True) + model = _train(client, X, y, params, model_factory, sample_weight, **kwargs) + + self.set_params(**model.get_params()) + self._copy_extra_params(model, self) + + return self + fit.__doc__ = LocalLGBMClassifier.fit.__doc__ def predict(self, X, **kwargs): """Docstring is inherited from the lightgbm.LGBMClassifier.predict.""" @@ -263,17 +243,48 @@ def predict_proba(self, X, **kwargs): return _predict(self.to_local(), X, proba=True, **kwargs) predict_proba.__doc__ = LocalLGBMClassifier.predict_proba.__doc__ + def to_local(self): + """Create regular version of lightgbm.LGBMClassifier from the distributed version. 
+ + Returns + ------- + model : lightgbm.LGBMClassifier + """ + model = LocalLGBMClassifier(**self.get_params()) + self._copy_extra_params(self, model) + return model + class LGBMRegressor(_LGBMModel, LocalLGBMRegressor): """Docstring is inherited from the lightgbm.LGBMRegressor.""" - def __init__(self, client=None, **kwargs) -> None: - super().__init__(LocalLGBMRegressor, client) - super(LocalLGBMRegressor, self).__init__(**kwargs) + def fit(self, X, y=None, sample_weight=None, client=None, **kwargs): + """Docstring is inherited from the lightgbm.LGBMRegressor.fit.""" + if client is None: + client = default_client() - _LGBMModel.fit.__doc__ = LocalLGBMRegressor.fit.__doc__ + model_factory = LocalLGBMRegressor + params = self.get_params(True) + model = _train(client, X, y, params, model_factory, sample_weight, **kwargs) + + self.set_params(**model.get_params()) + self._copy_extra_params(model, self) + + return self + fit.__doc__ = LocalLGBMRegressor.fit.__doc__ def predict(self, X, **kwargs): """Docstring is inherited from the lightgbm.LGBMRegressor.predict.""" return _predict(self.to_local(), X, **kwargs) predict.__doc__ = LocalLGBMRegressor.predict.__doc__ + + def to_local(self): + """Create regular version of lightgbm.LGBMRegressor from the distributed version. + + Returns + ------- + model : lightgbm.LGBMRegressor + """ + model = LocalLGBMRegressor(**self.get_params()) + self._copy_extra_params(self, model) + return model diff --git a/python-package/setup.py b/python-package/setup.py index 4f41626c7d17..d9199df53e02 100644 --- a/python-package/setup.py +++ b/python-package/setup.py @@ -345,7 +345,8 @@ def run(self): 'dask[array]>=2.0.0', 'dask[dataframe]>=2.0.0' 'dask[distributed]>=2.0.0', - 'pandas' + 'pandas', + 'toolz' ], }, maintainer='Guolin Ke', diff --git a/tests/python_package_test/test_dask.py b/tests/python_package_test/test_dask.py index d64bc6924049..242a77716a01 100644 --- a/tests/python_package_test/test_dask.py +++ b/tests/python_package_test/test_dask.py @@ -3,20 +3,16 @@ import dask.array as da import dask.dataframe as dd -import lightgbm import numpy as np import pandas as pd import pytest import scipy.sparse from dask.array.utils import assert_eq -try: - from dask_ml.metrics import accuracy_score, r2_score -except ImportError: - from sklearn.metrics import accuracy_score, r2_score - -from distributed.utils_test import client, cluster_fixture, loop, gen_cluster # noqa +from dask_ml.metrics import accuracy_score, r2_score +from distributed.utils_test import client, cluster_fixture, gen_cluster, loop # noqa from sklearn.datasets import make_blobs, make_regression +import lightgbm import lightgbm.dask as dlgbm data_output = ['array', 'scipy_csr_matrix', 'dataframe'] @@ -73,8 +69,8 @@ def _create_data(objective, n_samples=100, centers=2, output='array', chunk_size def test_classifier(output, centers, client, listen_port): # noqa X, y, w, dX, dy, dw = _create_data('classification', output=output, centers=centers) - a = dlgbm.LGBMClassifier(client=client, time_out=5, local_listen_port=listen_port) - a = a.fit(dX, dy, sample_weight=dw) + a = dlgbm.LGBMClassifier(time_out=5, local_listen_port=listen_port) + a = a.fit(dX, dy, sample_weight=dw, client=client) p1 = a.predict(dX) s1 = accuracy_score(dy, p1) p1 = p1.compute() @@ -96,8 +92,8 @@ def test_classifier(output, centers, client, listen_port): # noqa def test_classifier_proba(output, centers, client, listen_port): # noqa X, y, w, dX, dy, dw = _create_data('classification', output=output, centers=centers) - a = 
dlgbm.LGBMClassifier(client=client, time_out=5, local_listen_port=listen_port) - a = a.fit(dX, dy, sample_weight=dw) + a = dlgbm.LGBMClassifier(time_out=5, local_listen_port=listen_port) + a = a.fit(dX, dy, sample_weight=dw, client=client) p1 = a.predict_proba(dX) p1 = p1.compute() @@ -111,8 +107,8 @@ def test_classifier_proba(output, centers, client, listen_port): # noqa def test_classifier_local_predict(client, listen_port): # noqa X, y, w, dX, dy, dw = _create_data('classification', output='array') - a = dlgbm.LGBMClassifier(client=client, time_out=5, local_listen_port=listen_port) - a = a.fit(dX, dy, sample_weight=dw) + a = dlgbm.LGBMClassifier(time_out=5, local_listen_port=listen_port) + a = a.fit(dX, dy, sample_weight=dw, client=client) p1 = a.to_local().predict(dX) b = lightgbm.LGBMClassifier() @@ -128,8 +124,8 @@ def test_classifier_local_predict(client, listen_port): # noqa def test_regressor(output, client, listen_port): # noqa X, y, w, dX, dy, dw = _create_data('regression', output=output) - a = dlgbm.LGBMRegressor(client=client, time_out=5, local_listen_port=listen_port, seed=42) - a = a.fit(dX, dy, sample_weight=dw) + a = dlgbm.LGBMRegressor(time_out=5, local_listen_port=listen_port, seed=42) + a = a.fit(dX, dy, client=client, sample_weight=dw) p1 = a.predict(dX) if output != 'dataframe': s1 = r2_score(dy, p1) @@ -154,7 +150,7 @@ def test_regressor(output, client, listen_port): # noqa def test_regressor_quantile(output, client, listen_port, alpha): # noqa X, y, w, dX, dy, dw = _create_data('regression', output=output) - a = dlgbm.LGBMRegressor(client=client, local_listen_port=listen_port, seed=42, objective='quantile', alpha=alpha) + a = dlgbm.LGBMRegressor(local_listen_port=listen_port, seed=42, objective='quantile', alpha=alpha) a = a.fit(dX, dy, client=client, sample_weight=dw) p1 = a.predict(dX).compute() q1 = np.count_nonzero(y < p1) / y.shape[0] @@ -172,8 +168,8 @@ def test_regressor_quantile(output, client, listen_port, alpha): # noqa def test_regressor_local_predict(client, listen_port): # noqa X, y, w, dX, dy, dw = _create_data('regression', output='array') - a = dlgbm.LGBMRegressor(client=client, local_listen_port=listen_port, seed=42) - a = a.fit(dX, dy, sample_weight=dw) + a = dlgbm.LGBMRegressor(local_listen_port=listen_port, seed=42) + a = a.fit(dX, dy, sample_weight=dw, client=client) p1 = a.predict(dX) p2 = a.to_local().predict(X) s1 = r2_score(dy, p1) From 223c0d8ad62db2a31a7772713c6f5d81e834df7e Mon Sep 17 00:00:00 2001 From: Jan Stiborek Date: Sun, 15 Nov 2020 16:59:30 +0100 Subject: [PATCH 13/19] code deduplication - fit method moved into mixin class --- python-package/lightgbm/dask.py | 52 +++++++++++++++------------------ 1 file changed, 23 insertions(+), 29 deletions(-) diff --git a/python-package/lightgbm/dask.py b/python-package/lightgbm/dask.py index d0f793ec6203..7d652feb21b6 100644 --- a/python-package/lightgbm/dask.py +++ b/python-package/lightgbm/dask.py @@ -40,7 +40,7 @@ def _build_network_params(worker_addresses, local_worker_ip, local_listen_port, worker_addresses : iterable of str - collection of worker addresses in `://:port` format local_worker_ip : str local_listen_port : int - listen_time_out : int + time_out : int Returns ------- @@ -206,6 +206,24 @@ def _predict(model, data, proba=False, dtype=np.float32, **kwargs): class _LGBMModel: + def _fit(self, model_factory, X, y=None, sample_weight=None, client=None, **kwargs): + """Docstring is inherited from the LGBMModel.""" + if client is None: + client = default_client() + + params = 
self.get_params(True) + model = _train(client, X, y, params, model_factory, sample_weight, **kwargs) + + self.set_params(**model.get_params()) + self._copy_extra_params(model, self) + + return self + + def _to_local(self, model_factory): + model = model_factory(**self.get_params()) + self._copy_extra_params(self, model) + return model + @staticmethod def _copy_extra_params(source, dest): params = source.get_params() @@ -220,17 +238,7 @@ class LGBMClassifier(_LGBMModel, LocalLGBMClassifier): def fit(self, X, y=None, sample_weight=None, client=None, **kwargs): """Docstring is inherited from the LGBMModel.""" - if client is None: - client = default_client() - - model_factory = LocalLGBMClassifier - params = self.get_params(True) - model = _train(client, X, y, params, model_factory, sample_weight, **kwargs) - - self.set_params(**model.get_params()) - self._copy_extra_params(model, self) - - return self + return self._fit(LocalLGBMClassifier, X, y, sample_weight, client, **kwargs) fit.__doc__ = LocalLGBMClassifier.fit.__doc__ def predict(self, X, **kwargs): @@ -250,9 +258,7 @@ def to_local(self): ------- model : lightgbm.LGBMClassifier """ - model = LocalLGBMClassifier(**self.get_params()) - self._copy_extra_params(self, model) - return model + return self._to_local(LocalLGBMClassifier) class LGBMRegressor(_LGBMModel, LocalLGBMRegressor): @@ -260,17 +266,7 @@ class LGBMRegressor(_LGBMModel, LocalLGBMRegressor): def fit(self, X, y=None, sample_weight=None, client=None, **kwargs): """Docstring is inherited from the lightgbm.LGBMRegressor.fit.""" - if client is None: - client = default_client() - - model_factory = LocalLGBMRegressor - params = self.get_params(True) - model = _train(client, X, y, params, model_factory, sample_weight, **kwargs) - - self.set_params(**model.get_params()) - self._copy_extra_params(model, self) - - return self + return self._fit(LocalLGBMRegressor, X, y, sample_weight, client, **kwargs) fit.__doc__ = LocalLGBMRegressor.fit.__doc__ def predict(self, X, **kwargs): @@ -285,6 +281,4 @@ def to_local(self): ------- model : lightgbm.LGBMRegressor """ - model = LocalLGBMRegressor(**self.get_params()) - self._copy_extra_params(self, model) - return model + return self._to_local(LocalLGBMRegressor) From a3282d5d48dc3769a451b7123c28111938ef701f Mon Sep 17 00:00:00 2001 From: Jan Stiborek Date: Mon, 16 Nov 2020 19:01:00 +0100 Subject: [PATCH 14/19] fixed CODEOWNERS --- .github/CODEOWNERS | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index a492c6532ff6..7812ccdc3bff 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -31,8 +31,8 @@ R-package/ @Laurae2 @jameslamb python-package/ @StrikerRUS @chivee @wxchan @henry0312 # Dask integration -python-package/lightgbm/dask_distributed.py @jameslamb -tests/python_package_test/test_dask_distributed.py @jameslamb +python-package/lightgbm/dask.py @jameslamb +tests/python_package_test/test_dask.py @jameslamb # helpers helpers/ @StrikerRUS @guolinke From b11df27fc51139729670e55933075f90cc3068ab Mon Sep 17 00:00:00 2001 From: Jan Stiborek Date: Mon, 16 Nov 2020 19:01:39 +0100 Subject: [PATCH 15/19] removed unnecessary import --- python-package/lightgbm/dask.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python-package/lightgbm/dask.py b/python-package/lightgbm/dask.py index 7d652feb21b6..f5ee52bee2e1 100644 --- a/python-package/lightgbm/dask.py +++ b/python-package/lightgbm/dask.py @@ -15,7 +15,6 @@ from dask.distributed import default_client, get_worker, wait from toolz 
import assoc, first -import lightgbm from .basic import _LIB, _safe_call from .sklearn import LGBMClassifier as LocalLGBMClassifier, LGBMRegressor as LocalLGBMRegressor From 07bd8a6df5f01ff2787420088d53d0ac04f313bb Mon Sep 17 00:00:00 2001 From: Jan Stiborek Date: Mon, 16 Nov 2020 19:46:52 +0100 Subject: [PATCH 16/19] skip the module execution on python < 3.6 and on platform different than linux. --- tests/python_package_test/test_dask.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/tests/python_package_test/test_dask.py b/tests/python_package_test/test_dask.py index 242a77716a01..875a8b056552 100644 --- a/tests/python_package_test/test_dask.py +++ b/tests/python_package_test/test_dask.py @@ -1,11 +1,16 @@ import os import sys +import pytest +if not sys.platform.startswith("linux"): + pytest.skip("lightgbm.dask is currently supported in Linux environments", allow_module_level=True) +if sys.version_info < (3, 6): + pytest.skip("Only python>=3.6 is supported", allow_module_level=True) + import dask.array as da import dask.dataframe as dd import numpy as np import pandas as pd -import pytest import scipy.sparse from dask.array.utils import assert_eq from dask_ml.metrics import accuracy_score, r2_score @@ -19,8 +24,6 @@ data_centers = [[[-4, -4], [4, 4]], [[-4, -4], [4, 4], [-4, 4]]] pytestmark = [ - pytest.mark.skipif(sys.platform != "linux", reason="Only linux is currently supported"), - pytest.mark.skipif(sys.version_info < (3, 6), reason="Only python>=3.6 is supported"), pytest.mark.skipif(os.getenv("TASK", "") == "mpi", reason="Fails to run with MPI interface") ] From 8602c7825ecde7b8326dcb01e231ad8478588597 Mon Sep 17 00:00:00 2001 From: Jan Stiborek Date: Wed, 9 Dec 2020 20:53:02 +0100 Subject: [PATCH 17/19] removed skip for python < 3.6 --- tests/python_package_test/test_dask.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/python_package_test/test_dask.py b/tests/python_package_test/test_dask.py index 875a8b056552..5a445fa91163 100644 --- a/tests/python_package_test/test_dask.py +++ b/tests/python_package_test/test_dask.py @@ -4,8 +4,6 @@ import pytest if not sys.platform.startswith("linux"): pytest.skip("lightgbm.dask is currently supported in Linux environments", allow_module_level=True) -if sys.version_info < (3, 6): - pytest.skip("Only python>=3.6 is supported", allow_module_level=True) import dask.array as da import dask.dataframe as dd From 5270d634529231da82b875c0f0e091e7594dccb8 Mon Sep 17 00:00:00 2001 From: Jan Stiborek Date: Sun, 13 Dec 2020 09:07:27 +0100 Subject: [PATCH 18/19] review comments --- .ci/test.sh | 2 +- python-package/lightgbm/dask.py | 17 +++++++---------- python-package/setup.py | 1 - tests/python_package_test/test_dask.py | 23 ++++++++++++----------- 4 files changed, 20 insertions(+), 23 deletions(-) diff --git a/.ci/test.sh b/.ci/test.sh index 0c28943d2458..cf3f51e3899a 100755 --- a/.ci/test.sh +++ b/.ci/test.sh @@ -70,7 +70,7 @@ if [[ $TASK == "if-else" ]]; then exit 0 fi -conda install -q -y -n $CONDA_ENV joblib matplotlib numpy pandas psutil pytest python-graphviz scikit-learn scipy dask distributed dask-ml +conda install -q -y -n $CONDA_ENV dask dask-ml distributed joblib matplotlib numpy pandas psutil pytest python-graphviz scikit-learn scipy if [[ $OS_NAME == "macos" ]] && [[ $COMPILER == "clang" ]]; then # fix "OMP: Error #15: Initializing libiomp5.dylib, but found libomp.dylib already initialized." 
(OpenMP library conflict due to conda's MKL) diff --git a/python-package/lightgbm/dask.py b/python-package/lightgbm/dask.py index f5ee52bee2e1..8f5ac1ee2ffb 100644 --- a/python-package/lightgbm/dask.py +++ b/python-package/lightgbm/dask.py @@ -1,3 +1,4 @@ +# coding: utf-8 """Distributed training with LightGBM and Dask.distributed. This module enables you to perform distributed training with LightGBM on Dask.Array and Dask.DataFrame collections. @@ -13,15 +14,11 @@ from dask import dataframe as dd from dask import delayed from dask.distributed import default_client, get_worker, wait -from toolz import assoc, first from .basic import _LIB, _safe_call from .sklearn import LGBMClassifier as LocalLGBMClassifier, LGBMRegressor as LocalLGBMRegressor -try: - import scipy.sparse as ss -except ImportError: - ss = False +import scipy.sparse as ss logger = logging.getLogger(__name__) @@ -32,7 +29,7 @@ def _parse_host_port(address): def _build_network_params(worker_addresses, local_worker_ip, local_listen_port, time_out): - """Build network parameters suiltable for LightGBM C backend. + """Build network parameters suitable for LightGBM C backend. Parameters ---------- @@ -60,7 +57,7 @@ def _concat(seq): return np.concatenate(seq, axis=0) elif isinstance(seq[0], (pd.DataFrame, pd.Series)): return pd.concat(seq, axis=0) - elif ss and isinstance(seq[0], ss.spmatrix): + elif isinstance(seq[0], ss.spmatrix): return ss.vstack(seq, format='csr') else: raise TypeError('Data must be one of: numpy arrays, pandas dataframes, sparse matrices (from scipy). Got %s.' % str(type(seq[0]))) @@ -131,9 +128,9 @@ def _train(client, data, label, params, model_factory, weight=None, **kwargs): who_has = client.who_has(parts) worker_map = defaultdict(list) for key, workers in who_has.items(): - worker_map[first(workers)].append(key_to_part_dict[key]) + worker_map[next(iter(workers))].append(key_to_part_dict[key]) - master_worker = first(worker_map) + master_worker = next(iter(worker_map)) worker_ncores = client.ncores() if 'tree_learner' not in params or params['tree_learner'].lower() not in {'data', 'feature', 'voting'}: @@ -144,7 +141,7 @@ def _train(client, data, label, params, model_factory, weight=None, **kwargs): # Tell each worker to train on the parts that it has locally futures_classifiers = [client.submit(_train_part, model_factory=model_factory, - params=assoc(params, 'num_threads', worker_ncores[worker]), + params={**params, 'num_threads': worker_ncores[worker]}, list_of_parts=list_of_parts, worker_addresses=list(worker_map.keys()), local_listen_port=params.get('local_listen_port', 12400), diff --git a/python-package/setup.py b/python-package/setup.py index d9199df53e02..02ba32f3549c 100644 --- a/python-package/setup.py +++ b/python-package/setup.py @@ -346,7 +346,6 @@ def run(self): 'dask[dataframe]>=2.0.0' 'dask[distributed]>=2.0.0', 'pandas', - 'toolz' ], }, maintainer='Guolin Ke', diff --git a/tests/python_package_test/test_dask.py b/tests/python_package_test/test_dask.py index 5a445fa91163..2450e6ea5e5f 100644 --- a/tests/python_package_test/test_dask.py +++ b/tests/python_package_test/test_dask.py @@ -1,3 +1,4 @@ +# coding: utf-8 import os import sys @@ -70,16 +71,16 @@ def _create_data(objective, n_samples=100, centers=2, output='array', chunk_size def test_classifier(output, centers, client, listen_port): # noqa X, y, w, dX, dy, dw = _create_data('classification', output=output, centers=centers) - a = dlgbm.LGBMClassifier(time_out=5, local_listen_port=listen_port) - a = a.fit(dX, dy, sample_weight=dw, 
client=client) - p1 = a.predict(dX) + classifier_a = dlgbm.LGBMClassifier(time_out=5, local_listen_port=listen_port) + classifier_a = classifier_a.fit(dX, dy, sample_weight=dw, client=client) + p1 = classifier_a.predict(dX) s1 = accuracy_score(dy, p1) p1 = p1.compute() - b = lightgbm.LGBMClassifier() - b.fit(X, y, sample_weight=w) - p2 = b.predict(X) - s2 = b.score(X, y) + classifier_b = lightgbm.LGBMClassifier() + classifier_b.fit(X, y, sample_weight=w) + p2 = classifier_b.predict(X) + s2 = classifier_b.score(X, y) assert_eq(s1, s2) @@ -162,11 +163,11 @@ def test_regressor_quantile(output, client, listen_port, alpha): # noqa q2 = np.count_nonzero(y < p2) / y.shape[0] # Quantiles should be right - np.isclose(q1, alpha, atol=.1) - np.isclose(q2, alpha, atol=.1) + np.testing.assert_allclose(q1, alpha, atol=0.2) + np.testing.assert_allclose(q2, alpha, atol=0.2) -def test_regressor_local_predict(client, listen_port): # noqa +def test_regressor_local_predict(client, listen_port): X, y, w, dX, dy, dw = _create_data('regression', output='array') a = dlgbm.LGBMRegressor(local_listen_port=listen_port, seed=42) @@ -179,7 +180,7 @@ def test_regressor_local_predict(client, listen_port): # noqa # Predictions and scores should be the same assert_eq(p1, p2) - np.isclose(s1, s2) + assert_eq(s1, s2) def test_build_network_params(): From 1720d85cdd593e52a9d24f0044d8621e2658fa3c Mon Sep 17 00:00:00 2001 From: Jan Stiborek Date: Tue, 22 Dec 2020 09:41:50 +0100 Subject: [PATCH 19/19] removed noqa, renamed API classes, renamed local variables --- python-package/lightgbm/dask.py | 24 +++---- tests/python_package_test/test_dask.py | 96 +++++++++++++------------- 2 files changed, 60 insertions(+), 60 deletions(-) diff --git a/python-package/lightgbm/dask.py b/python-package/lightgbm/dask.py index 8f5ac1ee2ffb..66ab83591cdb 100644 --- a/python-package/lightgbm/dask.py +++ b/python-package/lightgbm/dask.py @@ -16,7 +16,7 @@ from dask.distributed import default_client, get_worker, wait from .basic import _LIB, _safe_call -from .sklearn import LGBMClassifier as LocalLGBMClassifier, LGBMRegressor as LocalLGBMRegressor +from .sklearn import LGBMClassifier, LGBMRegressor import scipy.sparse as ss @@ -229,23 +229,23 @@ def _copy_extra_params(source, dest): setattr(dest, name, attributes[name]) -class LGBMClassifier(_LGBMModel, LocalLGBMClassifier): +class DaskLGBMClassifier(_LGBMModel, LGBMClassifier): """Distributed version of lightgbm.LGBMClassifier.""" def fit(self, X, y=None, sample_weight=None, client=None, **kwargs): """Docstring is inherited from the LGBMModel.""" - return self._fit(LocalLGBMClassifier, X, y, sample_weight, client, **kwargs) - fit.__doc__ = LocalLGBMClassifier.fit.__doc__ + return self._fit(LGBMClassifier, X, y, sample_weight, client, **kwargs) + fit.__doc__ = LGBMClassifier.fit.__doc__ def predict(self, X, **kwargs): """Docstring is inherited from the lightgbm.LGBMClassifier.predict.""" return _predict(self.to_local(), X, dtype=self.classes_.dtype, **kwargs) - predict.__doc__ = LocalLGBMClassifier.predict.__doc__ + predict.__doc__ = LGBMClassifier.predict.__doc__ def predict_proba(self, X, **kwargs): """Docstring is inherited from the lightgbm.LGBMClassifier.predict_proba.""" return _predict(self.to_local(), X, proba=True, **kwargs) - predict_proba.__doc__ = LocalLGBMClassifier.predict_proba.__doc__ + predict_proba.__doc__ = LGBMClassifier.predict_proba.__doc__ def to_local(self): """Create regular version of lightgbm.LGBMClassifier from the distributed version. 
@@ -254,21 +254,21 @@ def to_local(self): ------- model : lightgbm.LGBMClassifier """ - return self._to_local(LocalLGBMClassifier) + return self._to_local(LGBMClassifier) -class LGBMRegressor(_LGBMModel, LocalLGBMRegressor): +class DaskLGBMRegressor(_LGBMModel, LGBMRegressor): """Docstring is inherited from the lightgbm.LGBMRegressor.""" def fit(self, X, y=None, sample_weight=None, client=None, **kwargs): """Docstring is inherited from the lightgbm.LGBMRegressor.fit.""" - return self._fit(LocalLGBMRegressor, X, y, sample_weight, client, **kwargs) - fit.__doc__ = LocalLGBMRegressor.fit.__doc__ + return self._fit(LGBMRegressor, X, y, sample_weight, client, **kwargs) + fit.__doc__ = LGBMRegressor.fit.__doc__ def predict(self, X, **kwargs): """Docstring is inherited from the lightgbm.LGBMRegressor.predict.""" return _predict(self.to_local(), X, **kwargs) - predict.__doc__ = LocalLGBMRegressor.predict.__doc__ + predict.__doc__ = LGBMRegressor.predict.__doc__ def to_local(self): """Create regular version of lightgbm.LGBMRegressor from the distributed version. @@ -277,4 +277,4 @@ def to_local(self): ------- model : lightgbm.LGBMRegressor """ - return self._to_local(LocalLGBMRegressor) + return self._to_local(LGBMRegressor) diff --git a/tests/python_package_test/test_dask.py b/tests/python_package_test/test_dask.py index 2450e6ea5e5f..d512cfcbda63 100644 --- a/tests/python_package_test/test_dask.py +++ b/tests/python_package_test/test_dask.py @@ -13,7 +13,7 @@ import scipy.sparse from dask.array.utils import assert_eq from dask_ml.metrics import accuracy_score, r2_score -from distributed.utils_test import client, cluster_fixture, gen_cluster, loop # noqa +from distributed.utils_test import client, cluster_fixture, gen_cluster, loop from sklearn.datasets import make_blobs, make_regression import lightgbm @@ -44,43 +44,43 @@ def _create_data(objective, n_samples=100, centers=2, output='array', chunk_size else: raise ValueError(objective) rnd = np.random.RandomState(42) - w = rnd.random(X.shape[0]) * 0.01 + weights = rnd.random(X.shape[0]) * 0.01 if output == 'array': dX = da.from_array(X, (chunk_size, X.shape[1])) dy = da.from_array(y, chunk_size) - dw = da.from_array(w, chunk_size) + dw = da.from_array(weights, chunk_size) elif output == 'dataframe': X_df = pd.DataFrame(X, columns=['feature_%d' % i for i in range(X.shape[1])]) y_df = pd.Series(y, name='target') dX = dd.from_pandas(X_df, chunksize=chunk_size) dy = dd.from_pandas(y_df, chunksize=chunk_size) - dw = dd.from_array(w, chunksize=chunk_size) + dw = dd.from_array(weights, chunksize=chunk_size) elif output == 'scipy_csr_matrix': dX = da.from_array(X, chunks=(chunk_size, X.shape[1])).map_blocks(scipy.sparse.csr_matrix) dy = da.from_array(y, chunks=chunk_size) - dw = da.from_array(w, chunk_size) + dw = da.from_array(weights, chunk_size) else: raise ValueError("Unknown output type %s" % output) - return X, y, w, dX, dy, dw + return X, y, weights, dX, dy, dw @pytest.mark.parametrize('output', data_output) @pytest.mark.parametrize('centers', data_centers) -def test_classifier(output, centers, client, listen_port): # noqa +def test_classifier(output, centers, client, listen_port): X, y, w, dX, dy, dw = _create_data('classification', output=output, centers=centers) - classifier_a = dlgbm.LGBMClassifier(time_out=5, local_listen_port=listen_port) - classifier_a = classifier_a.fit(dX, dy, sample_weight=dw, client=client) - p1 = classifier_a.predict(dX) + dask_classifier = dlgbm.DaskLGBMClassifier(time_out=5, local_listen_port=listen_port) + 
dask_classifier = dask_classifier.fit(dX, dy, sample_weight=dw, client=client) + p1 = dask_classifier.predict(dX) s1 = accuracy_score(dy, p1) p1 = p1.compute() - classifier_b = lightgbm.LGBMClassifier() - classifier_b.fit(X, y, sample_weight=w) - p2 = classifier_b.predict(X) - s2 = classifier_b.score(X, y) + local_classifier = lightgbm.LGBMClassifier() + local_classifier.fit(X, y, sample_weight=w) + p2 = local_classifier.predict(X) + s2 = local_classifier.score(X, y) assert_eq(s1, s2) @@ -91,31 +91,31 @@ def test_classifier(output, centers, client, listen_port): # noqa @pytest.mark.parametrize('output', data_output) @pytest.mark.parametrize('centers', data_centers) -def test_classifier_proba(output, centers, client, listen_port): # noqa +def test_classifier_proba(output, centers, client, listen_port): X, y, w, dX, dy, dw = _create_data('classification', output=output, centers=centers) - a = dlgbm.LGBMClassifier(time_out=5, local_listen_port=listen_port) - a = a.fit(dX, dy, sample_weight=dw, client=client) - p1 = a.predict_proba(dX) + dask_classifier = dlgbm.DaskLGBMClassifier(time_out=5, local_listen_port=listen_port) + dask_classifier = dask_classifier.fit(dX, dy, sample_weight=dw, client=client) + p1 = dask_classifier.predict_proba(dX) p1 = p1.compute() - b = lightgbm.LGBMClassifier() - b.fit(X, y, sample_weight=w) - p2 = b.predict_proba(X) + local_classifier = lightgbm.LGBMClassifier() + local_classifier.fit(X, y, sample_weight=w) + p2 = local_classifier.predict_proba(X) assert_eq(p1, p2, atol=0.3) -def test_classifier_local_predict(client, listen_port): # noqa +def test_classifier_local_predict(client, listen_port): X, y, w, dX, dy, dw = _create_data('classification', output='array') - a = dlgbm.LGBMClassifier(time_out=5, local_listen_port=listen_port) - a = a.fit(dX, dy, sample_weight=dw, client=client) - p1 = a.to_local().predict(dX) + dask_classifier = dlgbm.DaskLGBMClassifier(time_out=5, local_listen_port=listen_port) + dask_classifier = dask_classifier.fit(dX, dy, sample_weight=dw, client=client) + p1 = dask_classifier.to_local().predict(dX) - b = lightgbm.LGBMClassifier() - b.fit(X, y, sample_weight=w) - p2 = b.predict(X) + local_classifier = lightgbm.LGBMClassifier() + local_classifier.fit(X, y, sample_weight=w) + p2 = local_classifier.predict(X) assert_eq(p1, p2) assert_eq(y, p1) @@ -123,20 +123,20 @@ def test_classifier_local_predict(client, listen_port): # noqa @pytest.mark.parametrize('output', data_output) -def test_regressor(output, client, listen_port): # noqa +def test_regressor(output, client, listen_port): X, y, w, dX, dy, dw = _create_data('regression', output=output) - a = dlgbm.LGBMRegressor(time_out=5, local_listen_port=listen_port, seed=42) - a = a.fit(dX, dy, client=client, sample_weight=dw) - p1 = a.predict(dX) + dask_regressor = dlgbm.DaskLGBMRegressor(time_out=5, local_listen_port=listen_port, seed=42) + dask_regressor = dask_regressor.fit(dX, dy, client=client, sample_weight=dw) + p1 = dask_regressor.predict(dX) if output != 'dataframe': s1 = r2_score(dy, p1) p1 = p1.compute() - b = lightgbm.LGBMRegressor(seed=42) - b.fit(X, y, sample_weight=w) - s2 = b.score(X, y) - p2 = b.predict(X) + local_regressor = lightgbm.LGBMRegressor(seed=42) + local_regressor.fit(X, y, sample_weight=w) + s2 = local_regressor.score(X, y) + p2 = local_regressor.predict(X) # Scores should be the same if output != 'dataframe': @@ -149,17 +149,17 @@ def test_regressor(output, client, listen_port): # noqa @pytest.mark.parametrize('output', data_output) @pytest.mark.parametrize('alpha', 
[.1, .5, .9]) -def test_regressor_quantile(output, client, listen_port, alpha): # noqa +def test_regressor_quantile(output, client, listen_port, alpha): X, y, w, dX, dy, dw = _create_data('regression', output=output) - a = dlgbm.LGBMRegressor(local_listen_port=listen_port, seed=42, objective='quantile', alpha=alpha) - a = a.fit(dX, dy, client=client, sample_weight=dw) - p1 = a.predict(dX).compute() + dask_regressor = dlgbm.DaskLGBMRegressor(local_listen_port=listen_port, seed=42, objective='quantile', alpha=alpha) + dask_regressor = dask_regressor.fit(dX, dy, client=client, sample_weight=dw) + p1 = dask_regressor.predict(dX).compute() q1 = np.count_nonzero(y < p1) / y.shape[0] - b = lightgbm.LGBMRegressor(seed=42, objective='quantile', alpha=alpha) - b.fit(X, y, sample_weight=w) - p2 = b.predict(X) + local_regressor = lightgbm.LGBMRegressor(seed=42, objective='quantile', alpha=alpha) + local_regressor.fit(X, y, sample_weight=w) + p2 = local_regressor.predict(X) q2 = np.count_nonzero(y < p2) / y.shape[0] # Quantiles should be right @@ -170,13 +170,13 @@ def test_regressor_quantile(output, client, listen_port, alpha): # noqa def test_regressor_local_predict(client, listen_port): X, y, w, dX, dy, dw = _create_data('regression', output='array') - a = dlgbm.LGBMRegressor(local_listen_port=listen_port, seed=42) - a = a.fit(dX, dy, sample_weight=dw, client=client) - p1 = a.predict(dX) - p2 = a.to_local().predict(X) + dask_regressor = dlgbm.DaskLGBMRegressor(local_listen_port=listen_port, seed=42) + dask_regressor = dask_regressor.fit(dX, dy, sample_weight=dw, client=client) + p1 = dask_regressor.predict(dX) + p2 = dask_regressor.to_local().predict(X) s1 = r2_score(dy, p1) p1 = p1.compute() - s2 = a.to_local().score(X, y) + s2 = dask_regressor.to_local().score(X, y) # Predictions and scores should be the same assert_eq(p1, p2)
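
A note for readers on PATCH 18: the review removed the toolz dependency in favour of the standard library. The two helpers it used map onto stdlib idioms as sketched below; the sample dicts are made up purely for illustration and do not appear in the patches.

    # toolz.assoc(d, key, value) returns a *new* dict; dict unpacking
    # (PEP 448, Python >= 3.5) does the same without the dependency.
    params = {'tree_learner': 'data'}
    new_params = {**params, 'num_threads': 4}
    assert params == {'tree_learner': 'data'}  # the input dict is not mutated

    # toolz.first(seq) yields the first element of any iterable;
    # next(iter(seq)) is the stdlib equivalent, e.g. an arbitrary dict key.
    worker_map = {'tcp://10.0.0.1:8786': ['part-0'], 'tcp://10.0.0.2:8786': ['part-1']}
    master_worker = next(iter(worker_map))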
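
With PATCH 19 the series settles the public names: the distributed estimators are exposed as DaskLGBMClassifier and DaskLGBMRegressor, leaving LGBMClassifier and LGBMRegressor to mean the regular single-process sklearn wrappers. A minimal end-to-end sketch of the resulting API follows. Only the estimator names, the client= keyword on fit(), to_local(), and the time_out/local_listen_port/seed parameters come from the patches; the LocalCluster setup and the synthetic data are illustrative assumptions.

    import dask.array as da
    from dask.distributed import Client, LocalCluster
    from sklearn.datasets import make_regression

    from lightgbm.dask import DaskLGBMRegressor

    # an illustrative two-worker local cluster; any dask.distributed
    # client could be passed to fit() instead
    cluster = LocalCluster(n_workers=2)
    client = Client(cluster)

    # chunked dask arrays stand in for a genuinely distributed dataset
    X, y = make_regression(n_samples=1000, n_features=10, random_state=42)
    dX = da.from_array(X, chunks=(100, X.shape[1]))
    dy = da.from_array(y, chunks=100)

    # time_out and local_listen_port mirror the values used in the tests above
    dask_regressor = DaskLGBMRegressor(time_out=5, local_listen_port=12400, seed=42)
    dask_regressor = dask_regressor.fit(dX, dy, client=client)

    preds = dask_regressor.predict(dX).compute()  # distributed prediction
    local_model = dask_regressor.to_local()       # a plain lightgbm.LGBMRegressor
    local_preds = local_model.predict(X)          # single-process prediction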