[dask] allow parameter aliases for local_listen_port, num_threads, tree_learner (fixes #3671) #3789

Merged: 6 commits, Jan 20, 2021
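What this enables, sketched from the test changes below. A minimal illustration, assuming `lightgbm.dask` imported as `dlgbm` the way the tests do (dataset construction elided):

import lightgbm.dask as dlgbm

# Aliases now work anywhere the canonical names did: "local_port" for
# local_listen_port, "tree" for tree_learner, "n_jobs" for num_threads.
dask_regressor = dlgbm.DaskLGBMRegressor(
    local_port=12400,
    tree='data',
    n_jobs=4,
    n_estimators=10,
    num_leaves=10
)
# with a dask Client and dask collections dX, dy (as in the tests):
# dask_regressor.fit(dX, dy, client=client)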
python-package/lightgbm/basic.py (12 additions, 0 deletions)

@@ -238,6 +238,9 @@ class _ConfigAliases:
                  "sparse"},
     "label_column": {"label_column",
                      "label"},
+    "local_listen_port": {"local_listen_port",
+                          "local_port",
+                          "port"},
     "machines": {"machines",
                  "workers",
                  "nodes"},
@@ -255,12 +258,21 @@ class _ConfigAliases:
                     "num_rounds",
                     "num_boost_round",
                     "n_estimators"},
+    "num_threads": {"num_threads",
+                    "num_thread",
+                    "nthread",
+                    "nthreads",
+                    "n_jobs"},
     "objective": {"objective",
                   "objective_type",
                   "app",
                   "application"},
     "pre_partition": {"pre_partition",
                       "is_pre_partition"},
+    "tree_learner": {"tree_learner",
+                     "tree",
+                     "tree_type",
+                     "tree_learner_type"},
     "two_round": {"two_round",
                   "two_round_loading",
                   "use_two_round_loading"},
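The dask.py changes below lean on `_ConfigAliases.get`, which returns the full alias set registered for a parameter. A minimal sketch of the lookup pattern they use (the `resolve` helper is illustrative only, not part of this PR):

from lightgbm.basic import _ConfigAliases

def resolve(params, name, default=None):
    # Scan every registered alias of `name`; first hit wins. Because
    # _ConfigAliases.get returns a set, precedence between two different
    # aliases supplied at once is not deterministic.
    for alias in _ConfigAliases.get(name):
        value = params.get(alias)
        if value is not None:
            return value
    return default

params = {'nthread': 4, 'port': 12500}
resolve(params, 'num_threads')               # -> 4
resolve(params, 'local_listen_port', 12400)  # -> 12500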
python-package/lightgbm/dask.py (34 additions, 5 deletions)

@@ -7,6 +7,7 @@
 import logging
 import socket
 from collections import defaultdict
+from copy import deepcopy
 from typing import Dict, Iterable
 from urllib.parse import urlparse

@@ -19,7 +20,7 @@
 from dask import delayed
 from dask.distributed import Client, default_client, get_worker, wait

-from .basic import _LIB, _safe_call
+from .basic import _ConfigAliases, _LIB, _safe_call
 from .sklearn import LGBMClassifier, LGBMRegressor

 logger = logging.getLogger(__name__)
@@ -170,6 +171,8 @@ def _train(client, data, label, params, model_factory, weight=None, **kwargs):
     sample_weight : array-like of shape = [n_samples] or None, optional (default=None)
         Weights of training data.
     """
+    params = deepcopy(params)
+
     # Split arrays/dataframes into parts. Arrange parts into tuples to enforce co-locality
     data_parts = _split_to_parts(data, is_matrix=True)
     label_parts = _split_to_parts(label, is_matrix=False)
@@ -197,21 +200,47 @@ def _train(client, data, label, params, model_factory, weight=None, **kwargs):
     master_worker = next(iter(worker_map))
     worker_ncores = client.ncores()

-    if 'tree_learner' not in params or params['tree_learner'].lower() not in {'data', 'feature', 'voting'}:
-        logger.warning('Parameter tree_learner not set or set to incorrect value '
-                       '(%s), using "data" as default', params.get("tree_learner", None))
+    tree_learner = None
+    for tree_learner_param in _ConfigAliases.get('tree_learner'):
+        tree_learner = params.get(tree_learner_param)
+        if tree_learner is not None:
+            break
+
+    allowed_tree_learners = {
+        'data',
+        'data_parallel',
+        'feature',
+        'feature_parallel',
+        'voting',
+        'voting_parallel'
+    }
+    if tree_learner is None:
+        logger.warning('Parameter tree_learner not set. Using "data" as default')
+        params['tree_learner'] = 'data'
+    elif tree_learner.lower() not in allowed_tree_learners:
+        logger.warning('Parameter tree_learner set to %s, which is not allowed. Using "data" as default' % tree_learner)
+        params['tree_learner'] = 'data'
+
+    local_listen_port = 12400
+    for port_param in _ConfigAliases.get('local_listen_port'):
+        val = params.get(port_param)
+        if val is not None:
+            local_listen_port = val
+            break

     # find an open port on each worker. note that multiple workers can run
     # on the same machine, so this needs to ensure that each one gets its
     # own port
-    local_listen_port = params.get('local_listen_port', 12400)
     worker_address_to_port = _find_ports_for_workers(
         client=client,
         worker_addresses=worker_map.keys(),
         local_listen_port=local_listen_port
     )

+    # num_threads is set below, so remove it and all aliases of it from params
+    for num_thread_alias in _ConfigAliases.get('num_threads'):
+        params.pop(num_thread_alias, None)

     # Tell each worker to train on the parts that it has locally
     futures_classifiers = [client.submit(_train_part,
                                          model_factory=model_factory,
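`_find_ports_for_workers` itself is outside this diff. A hypothetical sketch of the per-host bookkeeping its comment describes; the function name, helper, and search strategy here are illustrative, not LightGBM's actual implementation, and the real helper would have to probe ports on the workers themselves rather than locally:

import socket
from collections import defaultdict
from urllib.parse import urlparse

def find_ports_for_workers_sketch(worker_addresses, local_listen_port=12400):
    """Assign a distinct port to every worker, tracking assignments per host."""
    ports_by_host = defaultdict(set)
    worker_to_port = {}
    for address in worker_addresses:
        # Worker addresses look like 'tcp://127.0.0.1:34567'; group by host
        # so two workers on one machine never share a port.
        host = urlparse(address).hostname
        port = local_listen_port
        while port in ports_by_host[host] or not _port_is_free(port):
            port += 1
        ports_by_host[host].add(port)
        worker_to_port[address] = port
    return worker_to_port

def _port_is_free(port):
    # In the real module this probe would need to run on the worker itself;
    # binding locally only makes sense for a single-machine cluster.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        try:
            s.bind(('', port))
            return True
        except OSError:
            return False

The `params.pop` loop at the end of the hunk exists because, per its comment, `num_threads` is set further down in `_train`; the earlier `worker_ncores = client.ncores()` suggests each worker receives a thread count matching its cores, so a user-supplied `n_jobs` or `nthread` would otherwise conflict with that per-worker value.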
tests/python_package_test/test_dask.py (7 additions, 4 deletions)

@@ -124,7 +124,7 @@ def test_classifier_local_predict(client, listen_port):

     dask_classifier = dlgbm.DaskLGBMClassifier(
         time_out=5,
-        local_listen_port=listen_port,
+        local_port=listen_port,
         n_estimators=10,
         num_leaves=10
     )
@@ -148,7 +148,8 @@ def test_regressor(output, client, listen_port):
         time_out=5,
         local_listen_port=listen_port,
         seed=42,
-        num_leaves=10
+        num_leaves=10,
+        tree='data'
     )
     dask_regressor = dask_regressor.fit(dX, dy, client=client, sample_weight=dw)
     p1 = dask_regressor.predict(dX)
@@ -181,7 +182,8 @@ def test_regressor_quantile(output, client, listen_port, alpha):
         objective='quantile',
         alpha=alpha,
         n_estimators=10,
-        num_leaves=10
+        num_leaves=10,
+        tree_learner_type='data_parallel'
     )
     dask_regressor = dask_regressor.fit(dX, dy, client=client, sample_weight=dw)
     p1 = dask_regressor.predict(dX).compute()
@@ -210,7 +212,8 @@ def test_regressor_local_predict(client, listen_port):
         local_listen_port=listen_port,
         seed=42,
         n_estimators=10,
-        num_leaves=10
+        num_leaves=10,
+        tree_type='data'
     )
     dask_regressor = dask_regressor.fit(dX, dy, sample_weight=dw, client=client)
     p1 = dask_regressor.predict(dX)
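Not part of the PR, but a quick sanity check of the alias groups these tests exercise, assuming the `_ConfigAliases` registry shown in the basic.py diff above:

from lightgbm.basic import _ConfigAliases

assert {'local_port', 'port'} <= _ConfigAliases.get('local_listen_port')
assert {'nthread', 'nthreads', 'n_jobs'} <= _ConfigAliases.get('num_threads')
assert {'tree', 'tree_type', 'tree_learner_type'} <= _ConfigAliases.get('tree_learner')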