diff --git a/python-package/lightgbm/basic.py b/python-package/lightgbm/basic.py index 9dbf443a79cc..f52776f320ca 100644 --- a/python-package/lightgbm/basic.py +++ b/python-package/lightgbm/basic.py @@ -238,6 +238,9 @@ class _ConfigAliases: "sparse"}, "label_column": {"label_column", "label"}, + "local_listen_port": {"local_listen_port", + "local_port", + "port"}, "machines": {"machines", "workers", "nodes"}, @@ -255,12 +258,21 @@ class _ConfigAliases: "num_rounds", "num_boost_round", "n_estimators"}, + "num_threads": {"num_threads", + "num_thread", + "nthread", + "nthreads", + "n_jobs"}, "objective": {"objective", "objective_type", "app", "application"}, "pre_partition": {"pre_partition", "is_pre_partition"}, + "tree_learner": {"tree_learner", + "tree", + "tree_type", + "tree_learner_type"}, "two_round": {"two_round", "two_round_loading", "use_two_round_loading"}, diff --git a/python-package/lightgbm/dask.py b/python-package/lightgbm/dask.py index 0e40e4534339..ba84edca87c2 100644 --- a/python-package/lightgbm/dask.py +++ b/python-package/lightgbm/dask.py @@ -7,6 +7,7 @@ import logging import socket from collections import defaultdict +from copy import deepcopy from typing import Dict, Iterable from urllib.parse import urlparse @@ -19,7 +20,7 @@ from dask import delayed from dask.distributed import Client, default_client, get_worker, wait -from .basic import _LIB, _safe_call +from .basic import _ConfigAliases, _LIB, _safe_call from .sklearn import LGBMClassifier, LGBMRegressor logger = logging.getLogger(__name__) @@ -170,6 +171,8 @@ def _train(client, data, label, params, model_factory, weight=None, **kwargs): sample_weight : array-like of shape = [n_samples] or None, optional (default=None) Weights of training data. """ + params = deepcopy(params) + # Split arrays/dataframes into parts. Arrange parts into tuples to enforce co-locality data_parts = _split_to_parts(data, is_matrix=True) label_parts = _split_to_parts(label, is_matrix=False) @@ -197,21 +200,47 @@ def _train(client, data, label, params, model_factory, weight=None, **kwargs): master_worker = next(iter(worker_map)) worker_ncores = client.ncores() - if 'tree_learner' not in params or params['tree_learner'].lower() not in {'data', 'feature', 'voting'}: - logger.warning('Parameter tree_learner not set or set to incorrect value ' - '(%s), using "data" as default', params.get("tree_learner", None)) + tree_learner = None + for tree_learner_param in _ConfigAliases.get('tree_learner'): + tree_learner = params.get(tree_learner_param) + if tree_learner is not None: + break + + allowed_tree_learners = { + 'data', + 'data_parallel', + 'feature', + 'feature_parallel', + 'voting', + 'voting_parallel' + } + if tree_learner is None: + logger.warning('Parameter tree_learner not set. Using "data" as default') params['tree_learner'] = 'data' + elif tree_learner.lower() not in allowed_tree_learners: + logger.warning('Parameter tree_learner set to %s, which is not allowed. Using "data" as default' % tree_learner) + params['tree_learner'] = 'data' + + local_listen_port = 12400 + for port_param in _ConfigAliases.get('local_listen_port'): + val = params.get(port_param) + if val is not None: + local_listen_port = val + break # find an open port on each worker. note that multiple workers can run # on the same machine, so this needs to ensure that each one gets its # own port - local_listen_port = params.get('local_listen_port', 12400) worker_address_to_port = _find_ports_for_workers( client=client, worker_addresses=worker_map.keys(), local_listen_port=local_listen_port ) + # num_threads is set below, so remove it and all aliases of it from params + for num_thread_alias in _ConfigAliases.get('num_threads'): + params.pop(num_thread_alias, None) + # Tell each worker to train on the parts that it has locally futures_classifiers = [client.submit(_train_part, model_factory=model_factory, diff --git a/tests/python_package_test/test_dask.py b/tests/python_package_test/test_dask.py index 1a454f6c6c87..e793872ee4fb 100644 --- a/tests/python_package_test/test_dask.py +++ b/tests/python_package_test/test_dask.py @@ -124,7 +124,7 @@ def test_classifier_local_predict(client, listen_port): dask_classifier = dlgbm.DaskLGBMClassifier( time_out=5, - local_listen_port=listen_port, + local_port=listen_port, n_estimators=10, num_leaves=10 ) @@ -148,7 +148,8 @@ def test_regressor(output, client, listen_port): time_out=5, local_listen_port=listen_port, seed=42, - num_leaves=10 + num_leaves=10, + tree='data' ) dask_regressor = dask_regressor.fit(dX, dy, client=client, sample_weight=dw) p1 = dask_regressor.predict(dX) @@ -181,7 +182,8 @@ def test_regressor_quantile(output, client, listen_port, alpha): objective='quantile', alpha=alpha, n_estimators=10, - num_leaves=10 + num_leaves=10, + tree_learner_type='data_parallel' ) dask_regressor = dask_regressor.fit(dX, dy, client=client, sample_weight=dw) p1 = dask_regressor.predict(dX).compute() @@ -210,7 +212,8 @@ def test_regressor_local_predict(client, listen_port): local_listen_port=listen_port, seed=42, n_estimators=10, - num_leaves=10 + num_leaves=10, + tree_type='data' ) dask_regressor = dask_regressor.fit(dX, dy, sample_weight=dw, client=client) p1 = dask_regressor.predict(dX)