Commit: review comments
SfinxCZ committed Dec 13, 2020
1 parent 8a77bd9 commit 5d2dcac
Showing 4 changed files with 20 additions and 23 deletions.
2 changes: 1 addition & 1 deletion .ci/test.sh
@@ -70,7 +70,7 @@ if [[ $TASK == "if-else" ]]; then
     exit 0
 fi

-conda install -q -y -n $CONDA_ENV joblib matplotlib numpy pandas psutil pytest python-graphviz scikit-learn scipy dask distributed dask-ml
+conda install -q -y -n $CONDA_ENV dask dask-ml distributed joblib matplotlib numpy pandas psutil pytest python-graphviz scikit-learn scipy

 if [[ $OS_NAME == "macos" ]] && [[ $COMPILER == "clang" ]]; then
     # fix "OMP: Error #15: Initializing libiomp5.dylib, but found libomp.dylib already initialized." (OpenMP library conflict due to conda's MKL)
17 changes: 7 additions & 10 deletions python-package/lightgbm/dask.py
@@ -1,3 +1,4 @@
+# coding: utf-8
 """Distributed training with LightGBM and Dask.distributed.

 This module enables you to perform distributed training with LightGBM on Dask.Array and Dask.DataFrame collections.
@@ -13,15 +14,11 @@
 from dask import dataframe as dd
 from dask import delayed
 from dask.distributed import default_client, get_worker, wait
-from toolz import assoc, first

 from .basic import _LIB, _safe_call
 from .sklearn import LGBMClassifier as LocalLGBMClassifier, LGBMRegressor as LocalLGBMRegressor

-try:
-    import scipy.sparse as ss
-except ImportError:
-    ss = False
+import scipy.sparse as ss

 logger = logging.getLogger(__name__)

@@ -32,7 +29,7 @@ def _parse_host_port(address):


 def _build_network_params(worker_addresses, local_worker_ip, local_listen_port, time_out):
-    """Build network parameters suiltable for LightGBM C backend.
+    """Build network parameters suitable for LightGBM C backend.

     Parameters
     ----------
@@ -60,7 +57,7 @@ def _concat(seq):
         return np.concatenate(seq, axis=0)
     elif isinstance(seq[0], (pd.DataFrame, pd.Series)):
         return pd.concat(seq, axis=0)
-    elif ss and isinstance(seq[0], ss.spmatrix):
+    elif isinstance(seq[0], ss.spmatrix):
         return ss.vstack(seq, format='csr')
     else:
         raise TypeError('Data must be one of: numpy arrays, pandas dataframes, sparse matrices (from scipy). Got %s.' % str(type(seq[0])))
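For readers skimming the hunk above: with `scipy.sparse` now imported unconditionally, the `ss and isinstance(...)` guard is gone and the sparse branch is always reachable. A self-contained sketch of the resulting dispatch (behavior copied from the diff, not new API):

```python
import numpy as np
import pandas as pd
import scipy.sparse as ss


def _concat(seq):
    # numpy arrays stack along the first axis
    if isinstance(seq[0], np.ndarray):
        return np.concatenate(seq, axis=0)
    # pandas objects concatenate row-wise
    elif isinstance(seq[0], (pd.DataFrame, pd.Series)):
        return pd.concat(seq, axis=0)
    # scipy is a hard dependency now, so no availability flag is needed
    elif isinstance(seq[0], ss.spmatrix):
        return ss.vstack(seq, format='csr')
    else:
        raise TypeError('Data must be one of: numpy arrays, pandas dataframes, sparse matrices (from scipy). Got %s.' % str(type(seq[0])))


print(_concat([ss.csr_matrix(np.eye(2)), ss.csr_matrix(np.ones((2, 2)))]).shape)  # (4, 2)
```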
@@ -131,9 +128,9 @@ def _train(client, data, label, params, model_factory, weight=None, **kwargs):
     who_has = client.who_has(parts)
     worker_map = defaultdict(list)
     for key, workers in who_has.items():
-        worker_map[first(workers)].append(key_to_part_dict[key])
+        worker_map[next(iter(workers))].append(key_to_part_dict[key])

-    master_worker = first(worker_map)
+    master_worker = next(iter(worker_map))
     worker_ncores = client.ncores()

     if 'tree_learner' not in params or params['tree_learner'].lower() not in {'data', 'feature', 'voting'}:
@@ -144,7 +141,7 @@ def _train(client, data, label, params, model_factory, weight=None, **kwargs):
     # Tell each worker to train on the parts that it has locally
     futures_classifiers = [client.submit(_train_part,
                                          model_factory=model_factory,
-                                         params=assoc(params, 'num_threads', worker_ncores[worker]),
+                                         params={**params, 'num_threads': worker_ncores[worker]},
                                          list_of_parts=list_of_parts,
                                          worker_addresses=list(worker_map.keys()),
                                          local_listen_port=params.get('local_listen_port', 12400),
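The two `toolz` usages removed in this file have exact stdlib equivalents, which is why the dependency could be dropped without changing behavior. A minimal sketch (the worker addresses are made up for illustration):

```python
# toolz.first(mapping) yields the first key in iteration order;
# next(iter(mapping)) does exactly the same.
worker_map = {'tcp://10.0.0.1:8786': ['part-0'], 'tcp://10.0.0.2:8786': ['part-1']}
master_worker = next(iter(worker_map))
assert master_worker == 'tcp://10.0.0.1:8786'

# toolz.assoc(d, k, v) returns a copy of d with k set to v, leaving d untouched;
# dict unpacking with an override key behaves identically.
params = {'objective': 'binary', 'num_leaves': 31}
per_worker_params = {**params, 'num_threads': 4}
assert per_worker_params['num_threads'] == 4
assert 'num_threads' not in params  # original params not mutated
```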
1 change: 0 additions & 1 deletion python-package/setup.py
@@ -350,7 +350,6 @@ def run(self):
             'dask[dataframe]>=2.0.0',
             'dask[distributed]>=2.0.0',
             'pandas',
-            'toolz'
         ],
     },
     maintainer='Guolin Ke',
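A note on lists like this `extras_require` block: Python concatenates adjacent string literals at parse time, so a missing comma between two requirements silently merges them into one malformed specifier instead of raising an error. A quick demonstration:

```python
# With commas: two separate requirements.
deps_ok = [
    'dask[dataframe]>=2.0.0',
    'dask[distributed]>=2.0.0',
]
# Comma dropped: implicit literal concatenation produces ONE bogus requirement.
deps_bad = [
    'dask[dataframe]>=2.0.0'
    'dask[distributed]>=2.0.0',
]
print(deps_ok)   # ['dask[dataframe]>=2.0.0', 'dask[distributed]>=2.0.0']
print(deps_bad)  # ['dask[dataframe]>=2.0.0dask[distributed]>=2.0.0']
```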
23 changes: 12 additions & 11 deletions tests/python_package_test/test_dask.py
@@ -1,3 +1,4 @@
+# coding: utf-8
 import os
 import sys

@@ -70,16 +71,16 @@ def _create_data(objective, n_samples=100, centers=2, output='array', chunk_size
 def test_classifier(output, centers, client, listen_port):  # noqa
     X, y, w, dX, dy, dw = _create_data('classification', output=output, centers=centers)

-    a = dlgbm.LGBMClassifier(time_out=5, local_listen_port=listen_port)
-    a = a.fit(dX, dy, sample_weight=dw, client=client)
-    p1 = a.predict(dX)
+    classifier_a = dlgbm.LGBMClassifier(time_out=5, local_listen_port=listen_port)
+    classifier_a = classifier_a.fit(dX, dy, sample_weight=dw, client=client)
+    p1 = classifier_a.predict(dX)
     s1 = accuracy_score(dy, p1)
     p1 = p1.compute()

-    b = lightgbm.LGBMClassifier()
-    b.fit(X, y, sample_weight=w)
-    p2 = b.predict(X)
-    s2 = b.score(X, y)
+    classifier_b = lightgbm.LGBMClassifier()
+    classifier_b.fit(X, y, sample_weight=w)
+    p2 = classifier_b.predict(X)
+    s2 = classifier_b.score(X, y)

     assert_eq(s1, s2)

@@ -162,11 +163,11 @@ def test_regressor_quantile(output, client, listen_port, alpha):  # noqa
     q2 = np.count_nonzero(y < p2) / y.shape[0]

     # Quantiles should be right
-    np.isclose(q1, alpha, atol=.1)
-    np.isclose(q2, alpha, atol=.1)
+    np.testing.assert_allclose(q1, alpha, atol=0.2)
+    np.testing.assert_allclose(q2, alpha, atol=0.2)


-def test_regressor_local_predict(client, listen_port):  # noqa
+def test_regressor_local_predict(client, listen_port):
     X, y, w, dX, dy, dw = _create_data('regression', output='array')

     a = dlgbm.LGBMRegressor(local_listen_port=listen_port, seed=42)
@@ -179,7 +180,7 @@ def test_regressor_local_predict(client, listen_port):  # noqa

     # Predictions and scores should be the same
     assert_eq(p1, p2)
-    np.isclose(s1, s2)
+    assert_eq(s1, s2)


 def test_build_network_params():
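The assertion changes in this file fix silent no-ops: `np.isclose` merely returns a boolean and asserts nothing, so the old quantile checks, like the old `np.isclose(s1, s2)` line in `test_regressor_local_predict`, could never fail. `np.testing.assert_allclose` and dask's `assert_eq` actually raise on mismatch. A minimal sketch of the difference:

```python
import numpy as np

q1, alpha = 0.55, 0.9

# Old style: the boolean result is discarded, so the "check" always passes.
np.isclose(q1, alpha, atol=0.2)  # returns False, raises nothing

# New style: raises AssertionError when values are not within tolerance.
try:
    np.testing.assert_allclose(q1, alpha, atol=0.2)
except AssertionError:
    print('mismatch caught')  # this branch runs
```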
