diff --git a/.ci/test.sh b/.ci/test.sh
index 4cd181790ad7..cf3f51e3899a 100755
--- a/.ci/test.sh
+++ b/.ci/test.sh
@@ -70,7 +70,7 @@ if [[ $TASK == "if-else" ]]; then
     exit 0
 fi
 
-conda install -q -y -n $CONDA_ENV joblib matplotlib numpy pandas psutil pytest python-graphviz scikit-learn scipy
+conda install -q -y -n $CONDA_ENV dask dask-ml distributed joblib matplotlib numpy pandas psutil pytest python-graphviz scikit-learn scipy
 
 if [[ $OS_NAME == "macos" ]] && [[ $COMPILER == "clang" ]]; then
     # fix "OMP: Error #15: Initializing libiomp5.dylib, but found libomp.dylib already initialized." (OpenMP library conflict due to conda's MKL)
diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS
index 7164a84e8ca5..7812ccdc3bff 100644
--- a/.github/CODEOWNERS
+++ b/.github/CODEOWNERS
@@ -30,6 +30,10 @@ R-package/ @Laurae2 @jameslamb
 # Python code
 python-package/ @StrikerRUS @chivee @wxchan @henry0312
 
+# Dask integration
+python-package/lightgbm/dask.py @jameslamb
+tests/python_package_test/test_dask.py @jameslamb
+
 # helpers
 helpers/ @StrikerRUS @guolinke
 
diff --git a/python-package/lightgbm/dask.py b/python-package/lightgbm/dask.py
new file mode 100644
index 000000000000..66ab83591cdb
--- /dev/null
+++ b/python-package/lightgbm/dask.py
@@ -0,0 +1,280 @@
+# coding: utf-8
+"""Distributed training with LightGBM and Dask.distributed.
+
+This module enables you to perform distributed training with LightGBM on Dask.Array and Dask.DataFrame collections.
+It is based on the dask-xgboost package.
+"""
+import logging
+from collections import defaultdict
+from urllib.parse import urlparse
+
+import numpy as np
+import pandas as pd
+import scipy.sparse as ss
+from dask import array as da
+from dask import dataframe as dd
+from dask import delayed
+from dask.distributed import default_client, get_worker, wait
+
+from .basic import _LIB, _safe_call
+from .sklearn import LGBMClassifier, LGBMRegressor
+
+logger = logging.getLogger(__name__)
+
+
+def _parse_host_port(address):
+    parsed = urlparse(address)
+    return parsed.hostname, parsed.port
+
+
+def _build_network_params(worker_addresses, local_worker_ip, local_listen_port, time_out):
+    """Build network parameters suitable for LightGBM C backend.
+
+    Parameters
+    ----------
+    worker_addresses : iterable of str - collection of worker addresses in `<protocol>://<host>:<port>` format
+    local_worker_ip : str
+    local_listen_port : int
+    time_out : int
+
+    Returns
+    -------
+    params : dict
+    """
+    addr_port_map = {addr: (local_listen_port + i) for i, addr in enumerate(worker_addresses)}
+    params = {
+        'machines': ','.join('%s:%d' % (_parse_host_port(addr)[0], port) for addr, port in addr_port_map.items()),
+        'local_listen_port': addr_port_map[local_worker_ip],
+        'time_out': time_out,
+        'num_machines': len(addr_port_map)
+    }
+    return params
+
+
+def _concat(seq):
+    if isinstance(seq[0], np.ndarray):
+        return np.concatenate(seq, axis=0)
+    elif isinstance(seq[0], (pd.DataFrame, pd.Series)):
+        return pd.concat(seq, axis=0)
+    elif isinstance(seq[0], ss.spmatrix):
+        return ss.vstack(seq, format='csr')
+    else:
+        raise TypeError('Data must be one of: numpy arrays, pandas dataframes, sparse matrices (from scipy). Got %s.'
+                        % str(type(seq[0])))
+
+
+def _train_part(params, model_factory, list_of_parts, worker_addresses, return_model, local_listen_port=12400,
+                time_out=120, **kwargs):
+    network_params = _build_network_params(worker_addresses, get_worker().address, local_listen_port, time_out)
+    params.update(network_params)
+
+    # Concatenate many parts into one
+    parts = tuple(zip(*list_of_parts))
+    data = _concat(parts[0])
+    label = _concat(parts[1])
+    weight = _concat(parts[2]) if len(parts) == 3 else None
+
+    try:
+        model = model_factory(**params)
+        model.fit(data, label, sample_weight=weight, **kwargs)
+    finally:
+        _safe_call(_LIB.LGBM_NetworkFree())
+
+    return model if return_model else None
+
+
+def _split_to_parts(data, is_matrix):
+    parts = data.to_delayed()
+    if isinstance(parts, np.ndarray):
+        assert (parts.shape[1] == 1) if is_matrix else (parts.ndim == 1 or parts.shape[1] == 1)
+        parts = parts.flatten().tolist()
+    return parts
+
+
+def _train(client, data, label, params, model_factory, weight=None, **kwargs):
+    """Inner train routine.
+
+    Parameters
+    ----------
+    client : dask.distributed.Client - Dask client used for computation
+    data : dask array of shape = [n_samples, n_features]
+        Input feature matrix.
+    label : dask array of shape = [n_samples]
+        The target values (class labels in classification, real numbers in regression).
+    params : dict
+    model_factory : lightgbm.LGBMClassifier or lightgbm.LGBMRegressor class
+    weight : array-like of shape = [n_samples] or None, optional (default=None)
+        Weights of training data.
+    """
+    # Split arrays/dataframes into parts. Arrange parts into tuples to enforce co-locality
+    data_parts = _split_to_parts(data, is_matrix=True)
+    label_parts = _split_to_parts(label, is_matrix=False)
+    if weight is None:
+        parts = list(map(delayed, zip(data_parts, label_parts)))
+    else:
+        weight_parts = _split_to_parts(weight, is_matrix=False)
+        parts = list(map(delayed, zip(data_parts, label_parts, weight_parts)))
+
+    # Start computation in the background
+    parts = client.compute(parts)
+    wait(parts)
+
+    for part in parts:
+        if part.status == 'error':
+            return part  # trigger error locally
+
+    # Find locations of all parts and map them to particular Dask workers
+    key_to_part_dict = dict([(part.key, part) for part in parts])
+    who_has = client.who_has(parts)
+    worker_map = defaultdict(list)
+    for key, workers in who_has.items():
+        worker_map[next(iter(workers))].append(key_to_part_dict[key])
+
+    master_worker = next(iter(worker_map))
+    worker_ncores = client.ncores()
+
+    if 'tree_learner' not in params or params['tree_learner'].lower() not in {'data', 'feature', 'voting'}:
+        logger.warning('Parameter tree_learner not set or set to incorrect value '
+                       '(%s), using "data" as default', params.get("tree_learner", None))
+        params['tree_learner'] = 'data'
+
+    # Tell each worker to train on the parts that it has locally
+    futures_classifiers = [client.submit(_train_part,
+                                         model_factory=model_factory,
+                                         params={**params, 'num_threads': worker_ncores[worker]},
+                                         list_of_parts=list_of_parts,
+                                         worker_addresses=list(worker_map.keys()),
+                                         local_listen_port=params.get('local_listen_port', 12400),
+                                         time_out=params.get('time_out', 120),
+                                         return_model=(worker == master_worker),
+                                         **kwargs)
+                           for worker, list_of_parts in worker_map.items()]
+
+    results = client.gather(futures_classifiers)
+    results = [v for v in results if v]
+    return results[0]
+
+
+def _predict_part(part, model, proba, **kwargs):
+    data = part.values if isinstance(part, pd.DataFrame) else part
+
+    if data.shape[0] == 0:
+        result = np.array([])
+    elif proba:
+        result = model.predict_proba(data, **kwargs)
+    else:
+        result = model.predict(data, **kwargs)
+
+    if isinstance(part, pd.DataFrame):
+        if proba:
+            result = pd.DataFrame(result, index=part.index)
+        else:
+            result = pd.Series(result, index=part.index, name='predictions')
+
+    return result
+
+
+def _predict(model, data, proba=False, dtype=np.float32, **kwargs):
+    """Inner predict routine.
+
+    Parameters
+    ----------
+    model : lightgbm.LGBMClassifier or lightgbm.LGBMRegressor - fitted local model used for prediction
+    data : dask array of shape = [n_samples, n_features]
+        Input feature matrix.
+    proba : bool
+        Whether to return the results of predict_proba (proba=True) or predict (proba=False).
+    dtype : np.dtype
+        Dtype of the output.
+    kwargs : other parameters passed to the predict or predict_proba method
+    """
+    if isinstance(data, dd._Frame):
+        return data.map_partitions(_predict_part, model=model, proba=proba, **kwargs).values
+    elif isinstance(data, da.Array):
+        if proba:
+            kwargs['chunks'] = (data.chunks[0], (model.n_classes_,))
+        else:
+            kwargs['drop_axis'] = 1
+        return data.map_blocks(_predict_part, model=model, proba=proba, dtype=dtype, **kwargs)
+    else:
+        raise TypeError('Data must be either Dask array or dataframe. Got %s.' % str(type(data)))
+
+
+class _LGBMModel:
+
+    def _fit(self, model_factory, X, y=None, sample_weight=None, client=None, **kwargs):
+        """Docstring is inherited from the LGBMModel."""
+        if client is None:
+            client = default_client()
+
+        params = self.get_params(True)
+        model = _train(client, X, y, params, model_factory, sample_weight, **kwargs)
+
+        self.set_params(**model.get_params())
+        self._copy_extra_params(model, self)
+
+        return self
+
+    def _to_local(self, model_factory):
+        model = model_factory(**self.get_params())
+        self._copy_extra_params(self, model)
+        return model
+
+    @staticmethod
+    def _copy_extra_params(source, dest):
+        params = source.get_params()
+        attributes = source.__dict__
+        extra_param_names = set(attributes.keys()).difference(params.keys())
+        for name in extra_param_names:
+            setattr(dest, name, attributes[name])
+
+
+class DaskLGBMClassifier(_LGBMModel, LGBMClassifier):
+    """Distributed version of lightgbm.LGBMClassifier."""
+
+    def fit(self, X, y=None, sample_weight=None, client=None, **kwargs):
+        """Docstring is inherited from the lightgbm.LGBMClassifier.fit."""
+        return self._fit(LGBMClassifier, X, y, sample_weight, client, **kwargs)
+    fit.__doc__ = LGBMClassifier.fit.__doc__
+
+    def predict(self, X, **kwargs):
+        """Docstring is inherited from the lightgbm.LGBMClassifier.predict."""
+        return _predict(self.to_local(), X, dtype=self.classes_.dtype, **kwargs)
+    predict.__doc__ = LGBMClassifier.predict.__doc__
+
+    def predict_proba(self, X, **kwargs):
+        """Docstring is inherited from the lightgbm.LGBMClassifier.predict_proba."""
+        return _predict(self.to_local(), X, proba=True, **kwargs)
+    predict_proba.__doc__ = LGBMClassifier.predict_proba.__doc__
+
+    def to_local(self):
+        """Create regular version of lightgbm.LGBMClassifier from the distributed version.
+
+        Returns
+        -------
+        model : lightgbm.LGBMClassifier
+        """
+        return self._to_local(LGBMClassifier)
+
+
+class DaskLGBMRegressor(_LGBMModel, LGBMRegressor):
+    """Distributed version of lightgbm.LGBMRegressor."""
+
+    def fit(self, X, y=None, sample_weight=None, client=None, **kwargs):
+        """Docstring is inherited from the lightgbm.LGBMRegressor.fit."""
+        return self._fit(LGBMRegressor, X, y, sample_weight, client, **kwargs)
+    fit.__doc__ = LGBMRegressor.fit.__doc__
+
+    def predict(self, X, **kwargs):
+        """Docstring is inherited from the lightgbm.LGBMRegressor.predict."""
+        return _predict(self.to_local(), X, **kwargs)
+    predict.__doc__ = LGBMRegressor.predict.__doc__
+
+    def to_local(self):
+        """Create regular version of lightgbm.LGBMRegressor from the distributed version.
+
+        Returns
+        -------
+        model : lightgbm.LGBMRegressor
+        """
+        return self._to_local(LGBMRegressor)
diff --git a/python-package/setup.py b/python-package/setup.py
index a2e8a0cf3560..02ba32f3549c 100644
--- a/python-package/setup.py
+++ b/python-package/setup.py
@@ -340,6 +340,14 @@ def run(self):
           'scipy',
           'scikit-learn!=0.22.0'
       ],
+      extras_require={
+          'dask': [
+              'dask[array]>=2.0.0',
+              'dask[dataframe]>=2.0.0',
+              'dask[distributed]>=2.0.0',
+              'pandas',
+          ],
+      },
       maintainer='Guolin Ke',
       maintainer_email='guolin.ke@microsoft.com',
       zip_safe=False,
diff --git a/tests/python_package_test/test_dask.py b/tests/python_package_test/test_dask.py
new file mode 100644
index 000000000000..d512cfcbda63
--- /dev/null
+++ b/tests/python_package_test/test_dask.py
@@ -0,0 +1,212 @@
+# coding: utf-8
+import os
+import sys
+
+import pytest
+if not sys.platform.startswith("linux"):
+    pytest.skip("lightgbm.dask is currently only supported in Linux environments", allow_module_level=True)
+
+import dask.array as da
+import dask.dataframe as dd
+import numpy as np
+import pandas as pd
+import scipy.sparse
+from dask.array.utils import assert_eq
+from dask_ml.metrics import accuracy_score, r2_score
+from distributed.utils_test import client, cluster_fixture, gen_cluster, loop
+from sklearn.datasets import make_blobs, make_regression
+
+import lightgbm
+import lightgbm.dask as dlgbm
+
+data_output = ['array', 'scipy_csr_matrix', 'dataframe']
+data_centers = [[[-4, -4], [4, 4]], [[-4, -4], [4, 4], [-4, 4]]]
+
+pytestmark = [
+    pytest.mark.skipif(os.getenv("TASK", "") == "mpi", reason="Fails to run with MPI interface")
+]
+
+
+@pytest.fixture()
+def listen_port():
+    listen_port.port += 10
+    return listen_port.port
+
+
+listen_port.port = 13000
+
+
+def _create_data(objective, n_samples=100, centers=2, output='array', chunk_size=50):
+    if objective == 'classification':
+        X, y = make_blobs(n_samples=n_samples, centers=centers, random_state=42)
+    elif objective == 'regression':
+        X, y = make_regression(n_samples=n_samples, random_state=42)
+    else:
+        raise ValueError(objective)
+    rnd = np.random.RandomState(42)
+    weights = rnd.random(X.shape[0]) * 0.01
+
+    if output == 'array':
+        dX = da.from_array(X, (chunk_size, X.shape[1]))
+        dy = da.from_array(y, chunk_size)
+        dw = da.from_array(weights, chunk_size)
+    elif output == 'dataframe':
+        X_df = pd.DataFrame(X, columns=['feature_%d' % i for i in range(X.shape[1])])
+        y_df = pd.Series(y, name='target')
+        dX = dd.from_pandas(X_df, chunksize=chunk_size)
+        dy = dd.from_pandas(y_df, chunksize=chunk_size)
+        dw = dd.from_array(weights, chunksize=chunk_size)
+    elif output == 'scipy_csr_matrix':
+        dX = da.from_array(X, chunks=(chunk_size, X.shape[1])).map_blocks(scipy.sparse.csr_matrix)
+        dy = da.from_array(y, chunks=chunk_size)
+        dw = da.from_array(weights, chunk_size)
+    else:
+        raise ValueError("Unknown output type %s" % output)
+
+    return X, y, weights, dX, dy, dw
+
+
+@pytest.mark.parametrize('output', data_output)
+@pytest.mark.parametrize('centers', data_centers)
+def test_classifier(output, centers, client, listen_port):
+    X, y, w, dX, dy, dw = _create_data('classification', output=output, centers=centers)
+
+    dask_classifier = dlgbm.DaskLGBMClassifier(time_out=5, local_listen_port=listen_port)
+    dask_classifier = dask_classifier.fit(dX, dy, sample_weight=dw, client=client)
+    p1 = dask_classifier.predict(dX)
+    s1 = accuracy_score(dy, p1)
+    p1 = p1.compute()
+
+    local_classifier = lightgbm.LGBMClassifier()
+    local_classifier.fit(X, y, sample_weight=w)
+    p2 = local_classifier.predict(X)
+    s2 = local_classifier.score(X, y)
+
+    assert_eq(s1, s2)
+
+    assert_eq(p1, p2)
+    assert_eq(y, p1)
+    assert_eq(y, p2)
+
+
+@pytest.mark.parametrize('output', data_output)
+@pytest.mark.parametrize('centers', data_centers)
+def test_classifier_proba(output, centers, client, listen_port):
+    X, y, w, dX, dy, dw = _create_data('classification', output=output, centers=centers)
+
+    dask_classifier = dlgbm.DaskLGBMClassifier(time_out=5, local_listen_port=listen_port)
+    dask_classifier = dask_classifier.fit(dX, dy, sample_weight=dw, client=client)
+    p1 = dask_classifier.predict_proba(dX)
+    p1 = p1.compute()
+
+    local_classifier = lightgbm.LGBMClassifier()
+    local_classifier.fit(X, y, sample_weight=w)
+    p2 = local_classifier.predict_proba(X)
+
+    assert_eq(p1, p2, atol=0.3)
+
+
+def test_classifier_local_predict(client, listen_port):
+    X, y, w, dX, dy, dw = _create_data('classification', output='array')
+
+    dask_classifier = dlgbm.DaskLGBMClassifier(time_out=5, local_listen_port=listen_port)
+    dask_classifier = dask_classifier.fit(dX, dy, sample_weight=dw, client=client)
+    p1 = dask_classifier.to_local().predict(dX)
+
+    local_classifier = lightgbm.LGBMClassifier()
+    local_classifier.fit(X, y, sample_weight=w)
+    p2 = local_classifier.predict(X)
+
+    assert_eq(p1, p2)
+    assert_eq(y, p1)
+    assert_eq(y, p2)
+
+
+@pytest.mark.parametrize('output', data_output)
+def test_regressor(output, client, listen_port):
+    X, y, w, dX, dy, dw = _create_data('regression', output=output)
+
+    dask_regressor = dlgbm.DaskLGBMRegressor(time_out=5, local_listen_port=listen_port, seed=42)
+    dask_regressor = dask_regressor.fit(dX, dy, client=client, sample_weight=dw)
+    p1 = dask_regressor.predict(dX)
+    if output != 'dataframe':
+        s1 = r2_score(dy, p1)
+    p1 = p1.compute()
+
+    local_regressor = lightgbm.LGBMRegressor(seed=42)
+    local_regressor.fit(X, y, sample_weight=w)
+    s2 = local_regressor.score(X, y)
+    p2 = local_regressor.predict(X)
+
+    # Scores should be the same
+    if output != 'dataframe':
+        assert_eq(s1, s2, atol=.01)
+
+    # Predictions should be roughly the same
+    assert_eq(y, p1, rtol=1., atol=100.)
+    assert_eq(y, p2, rtol=1., atol=50.)
+
+
+@pytest.mark.parametrize('output', data_output)
+@pytest.mark.parametrize('alpha', [.1, .5, .9])
+def test_regressor_quantile(output, client, listen_port, alpha):
+    X, y, w, dX, dy, dw = _create_data('regression', output=output)
+
+    dask_regressor = dlgbm.DaskLGBMRegressor(local_listen_port=listen_port, seed=42, objective='quantile', alpha=alpha)
+    dask_regressor = dask_regressor.fit(dX, dy, client=client, sample_weight=dw)
+    p1 = dask_regressor.predict(dX).compute()
+    q1 = np.count_nonzero(y < p1) / y.shape[0]
+
+    local_regressor = lightgbm.LGBMRegressor(seed=42, objective='quantile', alpha=alpha)
+    local_regressor.fit(X, y, sample_weight=w)
+    p2 = local_regressor.predict(X)
+    q2 = np.count_nonzero(y < p2) / y.shape[0]
+
+    # Quantiles should be right
+    np.testing.assert_allclose(q1, alpha, atol=0.2)
+    np.testing.assert_allclose(q2, alpha, atol=0.2)
+
+
+def test_regressor_local_predict(client, listen_port):
+    X, y, w, dX, dy, dw = _create_data('regression', output='array')
+
+    dask_regressor = dlgbm.DaskLGBMRegressor(local_listen_port=listen_port, seed=42)
+    dask_regressor = dask_regressor.fit(dX, dy, sample_weight=dw, client=client)
+    p1 = dask_regressor.predict(dX)
+    p2 = dask_regressor.to_local().predict(X)
+    s1 = r2_score(dy, p1)
+    p1 = p1.compute()
+    s2 = dask_regressor.to_local().score(X, y)
+
+    # Predictions and scores should be the same
+    assert_eq(p1, p2)
+    assert_eq(s1, s2)
+
+
+def test_build_network_params():
+    workers_ips = [
+        'tcp://192.168.0.1:34545',
+        'tcp://192.168.0.2:34346',
+        'tcp://192.168.0.3:34347'
+    ]
+
+    params = dlgbm._build_network_params(workers_ips, 'tcp://192.168.0.2:34346', 12400, 120)
+    exp_params = {
+        'machines': '192.168.0.1:12400,192.168.0.2:12401,192.168.0.3:12402',
+        'local_listen_port': 12401,
+        'num_machines': len(workers_ips),
+        'time_out': 120
+    }
+    assert exp_params == params
+
+
+@gen_cluster(client=True, timeout=None)
+def test_errors(c, s, a, b):
+    def f(part):
+        raise Exception('foo')
+
+    df = dd.demo.make_timeseries()
+    df = df.map_partitions(f, meta=df._meta)
+    with pytest.raises(Exception) as info:
+        yield dlgbm._train(c, df, df.x, params={}, model_factory=lightgbm.LGBMClassifier)
+    assert 'foo' in str(info.value)
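For context beyond the tests above, here is a minimal usage sketch of the estimators this diff adds, run against a local `dask.distributed` cluster. The cluster setup, synthetic dataset, and parameter values (`n_workers=2`, `n_estimators=50`, the chunk sizes) are illustrative assumptions, not part of the diff:

```python
import dask.array as da
from dask.distributed import Client, LocalCluster
from sklearn.datasets import make_blobs

from lightgbm.dask import DaskLGBMClassifier

if __name__ == '__main__':
    # Any dask.distributed cluster works; a local two-worker cluster
    # is enough to exercise the distributed code path.
    cluster = LocalCluster(n_workers=2)
    client = Client(cluster)

    # Chunk the data so that each worker holds several partitions.
    X, y = make_blobs(n_samples=1000, centers=2, random_state=42)
    dX = da.from_array(X, chunks=(100, X.shape[1]))
    dy = da.from_array(y, chunks=100)

    # Each worker trains on the chunks it holds locally; the workers
    # synchronize through LightGBM's own network protocol, with
    # tree_learner falling back to "data" if not set to a distributed mode.
    clf = DaskLGBMClassifier(n_estimators=50)
    clf = clf.fit(dX, dy, client=client)

    # predict() returns a lazy Dask array; compute() materializes it.
    preds = clf.predict(dX).compute()

    # to_local() yields a plain lightgbm.LGBMClassifier for
    # single-machine use (prediction also goes through it internally).
    local_clf = clf.to_local()
```

Since prediction always routes through `to_local()`, distributed and local predictions from the same fitted model should match, which is exactly what `test_classifier_local_predict` and `test_regressor_local_predict` assert.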