From 083e94039284dd6f49e711550ee68327719495e2 Mon Sep 17 00:00:00 2001
From: James Lamb
Date: Mon, 18 Jan 2021 22:49:55 -0600
Subject: [PATCH 01/11] [dask] allow parameter aliases for tree_learner and
 local_listen_port (fixes #3671)

---
 python-package/lightgbm/basic.py       |  7 ++++++
 python-package/lightgbm/dask.py        | 32 ++++++++++++++++++++++----
 tests/python_package_test/test_dask.py | 11 +++++----
 3 files changed, 41 insertions(+), 9 deletions(-)

diff --git a/python-package/lightgbm/basic.py b/python-package/lightgbm/basic.py
index 5ab1f7128b08..7cbf0d4205ee 100644
--- a/python-package/lightgbm/basic.py
+++ b/python-package/lightgbm/basic.py
@@ -238,6 +238,9 @@ class _ConfigAliases:
                             "sparse"},
                "label_column": {"label_column",
                                 "label"},
+               "local_listen_port": {"local_listen_port",
+                                     "local_port",
+                                     "port"},
                "machines": {"machines",
                             "workers",
                             "nodes"},
@@ -261,6 +264,10 @@ class _ConfigAliases:
                              "application"},
                "pre_partition": {"pre_partition",
                                  "is_pre_partition"},
+               "tree_learner": {"tree_learner",
+                                "tree",
+                                "tree_type",
+                                "tree_learner_type"},
                "two_round": {"two_round",
                              "two_round_loading",
                              "use_two_round_loading"},
diff --git a/python-package/lightgbm/dask.py b/python-package/lightgbm/dask.py
index fb8b06077e70..750e9081746e 100644
--- a/python-package/lightgbm/dask.py
+++ b/python-package/lightgbm/dask.py
@@ -17,7 +17,7 @@
 from dask import delayed
 from dask.distributed import Client, default_client, get_worker, wait
 
-from .basic import _LIB, _safe_call
+from .basic import _ConfigAliases, _LIB, _safe_call
 from .sklearn import LGBMClassifier, LGBMRegressor
 
 import scipy.sparse as ss
@@ -197,15 +197,37 @@ def _train(client, data, label, params, model_factory, weight=None, **kwargs):
     master_worker = next(iter(worker_map))
     worker_ncores = client.ncores()
 
-    if 'tree_learner' not in params or params['tree_learner'].lower() not in {'data', 'feature', 'voting'}:
-        logger.warning('Parameter tree_learner not set or set to incorrect value '
-                       '(%s), using "data" as default', params.get("tree_learner", None))
+    tree_learner = None
+    for tree_learner_param in _ConfigAliases.get('tree_learner'):
+        tree_learner = params.get(tree_learner_param)
+        if tree_learner is not None:
+            break
+
+    allowed_tree_learners = {
+        'data',
+        'data_parallel',
+        'feature',
+        'feature_parallel',
+        'voting',
+        'voting_parallel'
+    }
+    if tree_learner is None:
+        logger.warning('Parameter tree_learner not set. Using "data" as default"')
         params['tree_learner'] = 'data'
+    elif tree_learner.lower() not in allowed_tree_learners:
+        logger.warning('Parameter tree_learner set to %s, which is not allowed. Using "data" as default' % tree_learner)
+        params['tree_learner'] = 'data'
+
+    local_listen_port = 12400
+    for port_param in _ConfigAliases.get('local_listen_port'):
+        val = params.get(port_param)
+        if val is not None:
+            local_listen_port = val
+            break
 
     # find an open port on each worker. note that multiple workers can run
     # on the same machine, so this needs to ensure that each one gets its
     # own port
-    local_listen_port = params.get('local_listen_port', 12400)
     worker_address_to_port = _find_ports_for_workers(
         client=client,
         worker_addresses=worker_map.keys(),
diff --git a/tests/python_package_test/test_dask.py b/tests/python_package_test/test_dask.py
index 901584dafd9c..e556cd9c89ca 100644
--- a/tests/python_package_test/test_dask.py
+++ b/tests/python_package_test/test_dask.py
@@ -124,7 +124,7 @@ def test_classifier_local_predict(client, listen_port):
 
     dask_classifier = dlgbm.DaskLGBMClassifier(
         time_out=5,
-        local_listen_port=listen_port,
+        local_port=listen_port,
         n_estimators=10,
         num_leaves=10
     )
@@ -148,7 +148,8 @@ def test_regressor(output, client, listen_port):
         time_out=5,
         local_listen_port=listen_port,
         seed=42,
-        num_leaves=10
+        num_leaves=10,
+        tree='data'
     )
     dask_regressor = dask_regressor.fit(dX, dy, client=client, sample_weight=dw)
     p1 = dask_regressor.predict(dX)
@@ -181,7 +182,8 @@ def test_regressor_quantile(output, client, listen_port, alpha):
         objective='quantile',
         alpha=alpha,
         n_estimators=10,
-        num_leaves=10
+        num_leaves=10,
+        tree_learner_type='data_parallel'
     )
     dask_regressor = dask_regressor.fit(dX, dy, client=client, sample_weight=dw)
     p1 = dask_regressor.predict(dX).compute()
@@ -210,7 +212,8 @@ def test_regressor_local_predict(client, listen_port):
         local_listen_port=listen_port,
         seed=42,
         n_estimators=10,
-        num_leaves=10
+        num_leaves=10,
+        tree_type='data'
     )
     dask_regressor = dask_regressor.fit(dX, dy, sample_weight=dw, client=client)
     p1 = dask_regressor.predict(dX)
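[Editor's note: illustrative sketch, not part of the patch series. The alias
scan added to _train() in patch 01 can be exercised in isolation. The
_ConfigAliases below is a reduced stand-in for the internal class in
basic.py, included only so the snippet runs on its own.]

    # Reduced stand-in for lightgbm.basic._ConfigAliases (assumption: the
    # real class holds many more entries).
    class _ConfigAliases:
        aliases = {"tree_learner": {"tree_learner",
                                    "tree",
                                    "tree_type",
                                    "tree_learner_type"}}

        @classmethod
        def get(cls, name):
            return cls.aliases[name]

    # The same loop the patch adds to _train(): stop at the first alias
    # that has a value in params.
    params = {'tree_type': 'data_parallel', 'num_leaves': 10}
    tree_learner = None
    for tree_learner_param in _ConfigAliases.get('tree_learner'):
        tree_learner = params.get(tree_learner_param)
        if tree_learner is not None:
            break

    print(tree_learner)  # data_parallel
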
From 0f6331384dab36170d405b0ed8febf87c2689517 Mon Sep 17 00:00:00 2001
From: James Lamb
Date: Mon, 18 Jan 2021 23:14:09 -0600
Subject: [PATCH 02/11] num_thread too

---
 python-package/lightgbm/basic.py | 5 +++++
 python-package/lightgbm/dask.py  | 7 +++++++
 2 files changed, 12 insertions(+)

diff --git a/python-package/lightgbm/basic.py b/python-package/lightgbm/basic.py
index 7cbf0d4205ee..565a45e55846 100644
--- a/python-package/lightgbm/basic.py
+++ b/python-package/lightgbm/basic.py
@@ -258,6 +258,11 @@ class _ConfigAliases:
                              "num_rounds",
                              "num_boost_round",
                              "n_estimators"},
+               "num_threads": {"num_threads",
+                               "num_thread",
+                               "nthread",
+                               "nthreads",
+                               "n_jobs"},
                "objective": {"objective",
                              "objective_type",
                              "app",
diff --git a/python-package/lightgbm/dask.py b/python-package/lightgbm/dask.py
index 750e9081746e..924471d3dfe2 100644
--- a/python-package/lightgbm/dask.py
+++ b/python-package/lightgbm/dask.py
@@ -7,6 +7,7 @@
 import logging
 import socket
 from collections import defaultdict
+from copy import deepcopy
 from typing import Dict, Iterable
 from urllib.parse import urlparse
 
@@ -170,6 +171,8 @@ def _train(client, data, label, params, model_factory, weight=None, **kwargs):
     sample_weight : array-like of shape = [n_samples] or None, optional (default=None)
         Weights of training data.
     """
+    params = deepcopy(params)
+
     # Split arrays/dataframes into parts. Arrange parts into tuples to enforce co-locality
     data_parts = _split_to_parts(data, is_matrix=True)
     label_parts = _split_to_parts(label, is_matrix=False)
@@ -234,6 +237,10 @@ def _train(client, data, label, params, model_factory, weight=None, **kwargs):
         local_listen_port=local_listen_port
     )
 
+    # num_threads is set below, so remove it and all aliases of it from params
+    for num_thread_alias in _ConfigAliases.get('num_threads'):
+        _ = params.pop(num_thread_alias, None)
+
     # Tell each worker to train on the parts that it has locally
     futures_classifiers = [client.submit(_train_part,
                                          model_factory=model_factory,
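[Editor's note: illustrative sketch, not part of the patch series. The new
`params = deepcopy(params)` line matters because _train() now pops entries
out of params; without the copy, that would mutate the dict the caller
passed in. Plain Python, no LightGBM required:]

    from copy import deepcopy

    user_params = {'n_jobs': 4, 'num_leaves': 31}

    params = deepcopy(user_params)
    # mimic the alias cleanup _train() performs before submitting work
    for num_thread_alias in ('num_threads', 'num_thread', 'nthread', 'nthreads', 'n_jobs'):
        params.pop(num_thread_alias, None)

    print(params)       # {'num_leaves': 31}
    print(user_params)  # {'n_jobs': 4, 'num_leaves': 31} -- caller's dict is untouched
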
From 3778fa7a9bafb8f94dbefa103cad4b1e47cddfed Mon Sep 17 00:00:00 2001
From: James Lamb
Date: Tue, 19 Jan 2021 15:09:56 -0600
Subject: [PATCH 03/11] Apply suggestions from code review

Co-authored-by: Nikita Titov
---
 python-package/lightgbm/dask.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/python-package/lightgbm/dask.py b/python-package/lightgbm/dask.py
index aed0e71a5d72..ba84edca87c2 100644
--- a/python-package/lightgbm/dask.py
+++ b/python-package/lightgbm/dask.py
@@ -215,7 +215,7 @@ def _train(client, data, label, params, model_factory, weight=None, **kwargs):
         'voting_parallel'
     }
     if tree_learner is None:
-        logger.warning('Parameter tree_learner not set. Using "data" as default"')
+        logger.warning('Parameter tree_learner not set. Using "data" as default')
         params['tree_learner'] = 'data'
     elif tree_learner.lower() not in allowed_tree_learners:
         logger.warning('Parameter tree_learner set to %s, which is not allowed. Using "data" as default' % tree_learner)
         params['tree_learner'] = 'data'
@@ -239,7 +239,7 @@ def _train(client, data, label, params, model_factory, weight=None, **kwargs):
 
     # num_threads is set below, so remove it and all aliases of it from params
     for num_thread_alias in _ConfigAliases.get('num_threads'):
-        _ = params.pop(num_thread_alias, None)
+        params.pop(num_thread_alias, None)
 
     # Tell each worker to train on the parts that it has locally
     futures_classifiers = [client.submit(_train_part,

From db84675ef0a66f271460f6fcfb7edcf506b920ea Mon Sep 17 00:00:00 2001
From: James Lamb
Date: Wed, 20 Jan 2021 10:34:56 -0600
Subject: [PATCH 04/11] empty commit
From a9cacc1d381fd72d3470db2a570fe16a8fbe9b34 Mon Sep 17 00:00:00 2001
From: James Lamb
Date: Mon, 25 Jan 2021 00:04:25 -0600
Subject: [PATCH 05/11] add _choose_param_value

---
 python-package/lightgbm/basic.py        | 67 +++++++++++++++++++--
 python-package/lightgbm/dask.py         | 80 +++++++++++--------------
 tests/python_package_test/test_basic.py | 46 ++++++++++++++
 3 files changed, 141 insertions(+), 52 deletions(-)

diff --git a/python-package/lightgbm/basic.py b/python-package/lightgbm/basic.py
index f741f03de19e..2055e7bba05f 100644
--- a/python-package/lightgbm/basic.py
+++ b/python-package/lightgbm/basic.py
@@ -6,9 +6,11 @@
 import os
 import warnings
 from collections import OrderedDict
+from copy import deepcopy
 from functools import wraps
 from logging import Logger
 from tempfile import NamedTemporaryFile
+from typing import Any, Dict
 
 import numpy as np
 import scipy.sparse
@@ -352,6 +354,46 @@ def get(cls, *args):
         return ret
 
 
+def _choose_param_value(main_param_name: str, params: Dict[str, Any], default_value: Any):
+    """Get a single parameter value, accounting for aliases.
+
+    Parameters
+    ----------
+    main_param_name : str
+        Name of the main parameter to get a value for. One of the keys of ``_ConfigAliases``.
+    params : dict
+        Dictionary of LightGBM parameters.
+    default_value : Any
+        Default value to use for the parameter, if none is found in ``params``
+
+    Returns
+    -------
+    params : dict
+        A ``params`` dict with exactly one value for ``main_param_name``, and all aliases of ``main_param_name`` removed.
+        If both ``main_param_name`` and one or more aliases for it are found, the value of ``main_param_name`` will be preferred.
+    """
+    # avoid side effects on passed-in parameters
+    params = deepcopy(params)
+
+    # find a value, and remove other aliases with .pop()
+    # prefer the value of 'main_param_name' if it exists, otherwise search the aliases
+    found_value = None
+    if main_param_name in params.keys():
+        found_value = params[main_param_name]
+
+    for param in _ConfigAliases.get(main_param_name):
+        val = params.pop(param, None)
+        if found_value is None and val is not None:
+            found_value = val
+
+    if found_value is not None:
+        params[main_param_name] = found_value
+    else:
+        params[main_param_name] = default_value
+
+    return params
+
+
 MAX_INT32 = (1 << 31) - 1
 
 """Macro definition of data type in C API of LightGBM"""
@@ -2144,16 +2186,29 @@ def __init__(self, params=None, train_set=None, model_file=None, model_str=None,
             if alias in params:
                 machines = params[alias]
                 if isinstance(machines, str):
-                    num_machines = len(machines.split(','))
+                    num_machines_from_machine_list = len(machines.split(','))
                 elif isinstance(machines, (list, set)):
-                    num_machines = len(machines)
+                    num_machines_from_machine_list = len(machines)
                     machines = ','.join(machines)
                 else:
                     raise ValueError("Invalid machines in params.")
-                self.set_network(machines,
-                                 local_listen_port=params.get("local_listen_port", 12400),
-                                 listen_time_out=params.get("listen_time_out", 120),
-                                 num_machines=params.setdefault("num_machines", num_machines))
+
+                params = _choose_param_value(
+                    main_param_name="num_machines",
+                    params=params,
+                    default_value=num_machines_from_machine_list
+                )
+                params = _choose_param_value(
+                    main_param_name="local_listen_port",
+                    params=params,
+                    default_value=12400
+                )
+                self.set_network(
+                    machines=machines,
+                    local_listen_port=params["local_listen_port"],
+                    listen_time_out=params.get("listen_time_out", 120),
+                    num_machines=params["num_machines"]
+                )
                 break
         # construct booster object
         train_set.construct()
diff --git a/python-package/lightgbm/dask.py b/python-package/lightgbm/dask.py
index 4acbf10702d7..3173804f199e 100644
--- a/python-package/lightgbm/dask.py
+++ b/python-package/lightgbm/dask.py
@@ -21,7 +21,7 @@
 from dask import delayed
 from dask.distributed import Client, default_client, get_worker, wait
 
-from .basic import _ConfigAliases, _LIB, _log_warning, _safe_call, LightGBMError
+from .basic import _choose_param_value, _ConfigAliases, _LIB, _log_warning, _safe_call, LightGBMError
 from .compat import DASK_INSTALLED, PANDAS_INSTALLED, SKLEARN_INSTALLED
 from .sklearn import LGBMClassifier, LGBMRegressor, LGBMRanker
 
@@ -197,6 +197,38 @@ def _train(client, data, label, params, model_factory, sample_weight=None, group
     """
     params = deepcopy(params)
 
+    params = _choose_param_value(
+        main_param_name="local_listen_port",
+        params=params,
+        default_value=12400
+    )
+
+    params = _choose_param_value(
+        main_param_name="tree_learner",
+        params=params,
+        default_value="data"
+    )
+    allowed_tree_learners = {
+        'data',
+        'data_parallel',
+        'feature',
+        'feature_parallel',
+        'voting',
+        'voting_parallel'
+    }
+    if params["tree_learner"] not in allowed_tree_learners:
+        _log_warning('Parameter tree_learner set to %s, which is not allowed. Using "data" as default' % params['tree_learner'])
+        params['tree_learner'] = 'data'
+
+    # Some passed-inparameters can be removed:
+    # * 'machines': constructed automatically from Dask worker list
+    # * 'machine_list_filename': not relevant for the Dask interface
+    # * 'num_machines': set automatically from Dask worker list
+    # * 'num_threads': overridden to match nthreads on each Dask process
+    for param_name in ['machines', 'machine_list_filename', 'num_machines', 'num_threads']:
+        for param_alias in _ConfigAliases.get(param_name):
+            params.pop(param_alias, None)
+
     # Split arrays/dataframes into parts. Arrange parts into tuples to enforce co-locality
     data_parts = _split_to_parts(data=data, is_matrix=True)
     label_parts = _split_to_parts(data=label, is_matrix=False)
@@ -240,59 +272,15 @@ def _train(client, data, label, params, model_factory, sample_weight=None, group
     master_worker = next(iter(worker_map))
     worker_ncores = client.ncores()
 
-    tree_learner = None
-    for tree_learner_param in _ConfigAliases.get('tree_learner'):
-        tree_learner = params.get(tree_learner_param)
-        if tree_learner is not None:
-            break
-
-    allowed_tree_learners = {
-        'data',
-        'data_parallel',
-        'feature',
-        'feature_parallel',
-        'voting',
-        'voting_parallel'
-    }
-    if tree_learner is None:
-        _log_warning('Parameter tree_learner not set. Using "data" as default')
-        params['tree_learner'] = 'data'
-    elif tree_learner.lower() not in allowed_tree_learners:
-        _log_warning('Parameter tree_learner set to %s, which is not allowed. Using "data" as default' % tree_learner)
-        params['tree_learner'] = 'data'
-
-    local_listen_port = 12400
-    for port_param in _ConfigAliases.get('local_listen_port'):
-        val = params.get(port_param)
-        if val is not None:
-            local_listen_port = val
-            break
-
     # find an open port on each worker. note that multiple workers can run
     # on the same machine, so this needs to ensure that each one gets its
     # own port
     worker_address_to_port = _find_ports_for_workers(
         client=client,
         worker_addresses=worker_map.keys(),
-        local_listen_port=local_listen_port
+        local_listen_port=params["local_listen_port"]
    )
 
-    # num_threads is set below, so remove it and all aliases of it from params
-    for num_thread_alias in _ConfigAliases.get('num_threads'):
-        params.pop(num_thread_alias, None)
-
-    # machines is constructed manually, so remove it and all aliases of it from params
-    for machine_alias in _ConfigAliases.get('machines'):
-        params.pop(machine_alias, None)
-
-    # machines is constructed manually, so remove machine_list_filename and all aliases of it from params
-    for machine_list_filename_alias in _ConfigAliases.get('machine_list_filename'):
-        params.pop(machine_list_filename_alias, None)
-
-    # machines is constructed manually, so remove num_machines and all aliases of it from params
-    for num_machine_alias in _ConfigAliases.get('num_machines'):
-        params.pop(num_machine_alias, None)
-
     # Tell each worker to train on the parts that it has locally
     futures_classifiers = [
         client.submit(
diff --git a/tests/python_package_test/test_basic.py b/tests/python_package_test/test_basic.py
index 7cc349ded449..46cb8133514b 100644
--- a/tests/python_package_test/test_basic.py
+++ b/tests/python_package_test/test_basic.py
@@ -329,3 +329,49 @@ def check_asserts(data):
     lgb_data.set_init_score(sequence)
     lgb_data.set_feature_name(feature_names)
     check_asserts(lgb_data)
+
+
+def test_choose_param_value():
+
+    original_params = {
+        "local_listen_port": 1234,
+        "port": 2222,
+        "metric": "auc",
+        "num_trees": 81
+    }
+
+    # should resolve duplicate aliases, and prefer the main parameter
+    params = lgb.basic._choose_param_value(
+        main_param_name="local_listen_port",
+        params=original_params,
+        default_value=5555
+    )
+    assert params["local_listen_port"] == 1234
+    assert "port" not in params
+
+    # should choose a value from an alias and set that value on main param
+    # if only an alias is used
+    params = lgb.basic._choose_param_value(
+        main_param_name="num_iterations",
+        params=params,
+        default_value=17
+    )
+    assert params["num_iterations"] == 81
+    assert "random_state" not in params
+
+    # should use the default if main param and aliases are missing
+    params = lgb.basic._choose_param_value(
+        main_param_name="learning_rate",
+        params=params,
+        default_value=0.789
+    )
+    assert params["learning_rate"] == 0.789
+
+    # all changes should be made on copies and not modify the original
+    expected_params = {
+        "local_listen_port": 1234,
+        "port": 2222,
+        "metric": "auc",
+        "num_trees": 81
+    }
+    assert original_params == expected_params
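[Editor's note: illustrative sketch, not part of the patch series. Expected
behavior of the new helper, restated from its docstring and the unit test
above; assumes a LightGBM build with this patch applied.]

    from lightgbm.basic import _choose_param_value

    # Only an alias is present: its value is promoted to the main name.
    params = _choose_param_value(
        main_param_name="local_listen_port",
        params={"local_port": 12401, "num_leaves": 31},
        default_value=12400
    )
    print(params)  # {'num_leaves': 31, 'local_listen_port': 12401}

    # Neither the main name nor an alias is present: the default is used.
    params = _choose_param_value(
        main_param_name="local_listen_port",
        params={"num_leaves": 31},
        default_value=12400
    )
    print(params)  # {'num_leaves': 31, 'local_listen_port': 12400}
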
Using "data" as default' % tree_learner) + params['tree_learner'] = 'data' + + # Some passed-inparameters can be removed: + # * 'machines': constructed automatically from Dask worker list + # * 'machine_list_filename': not relevant for the Dask interface + # * 'num_machines': set automatically from Dask worker list + # * 'num_threads': overridden to match nthreads on each Dask process + for param_name in ['machines', 'machine_list_filename', 'num_machines', 'num_threads']: + for param_alias in _ConfigAliases.get(param_name): + params.pop(param_alias, None) + # Split arrays/dataframes into parts. Arrange parts into tuples to enforce co-locality data_parts = _split_to_parts(data=data, is_matrix=True) label_parts = _split_to_parts(data=label, is_matrix=False) @@ -240,59 +272,15 @@ def _train(client, data, label, params, model_factory, sample_weight=None, group master_worker = next(iter(worker_map)) worker_ncores = client.ncores() - tree_learner = None - for tree_learner_param in _ConfigAliases.get('tree_learner'): - tree_learner = params.get(tree_learner_param) - if tree_learner is not None: - break - - allowed_tree_learners = { - 'data', - 'data_parallel', - 'feature', - 'feature_parallel', - 'voting', - 'voting_parallel' - } - if tree_learner is None: - _log_warning('Parameter tree_learner not set. Using "data" as default') - params['tree_learner'] = 'data' - elif tree_learner.lower() not in allowed_tree_learners: - _log_warning('Parameter tree_learner set to %s, which is not allowed. Using "data" as default' % tree_learner) - params['tree_learner'] = 'data' - - local_listen_port = 12400 - for port_param in _ConfigAliases.get('local_listen_port'): - val = params.get(port_param) - if val is not None: - local_listen_port = val - break - # find an open port on each worker. 
From 7816cb583fada3b204d994b1a838f01a9ab055df Mon Sep 17 00:00:00 2001
From: James Lamb
Date: Mon, 25 Jan 2021 12:35:21 -0600
Subject: [PATCH 07/11] Apply suggestions from code review

Co-authored-by: Nikita Titov
---
 python-package/lightgbm/basic.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/python-package/lightgbm/basic.py b/python-package/lightgbm/basic.py
index 411c26fa3b91..51a33c3b45e3 100644
--- a/python-package/lightgbm/basic.py
+++ b/python-package/lightgbm/basic.py
@@ -364,7 +364,7 @@ def _choose_param_value(main_param_name: str, params: Dict[str, Any], default_va
     params : dict
         Dictionary of LightGBM parameters.
     default_value : Any
-        Default value to use for the parameter, if none is found in ``params``
+        Default value to use for the parameter, if none is found in ``params``.
 
     Returns
     -------
"metric_types"}, @@ -354,7 +354,7 @@ def get(cls, *args): return ret -def _choose_param_value(main_param_name: str, params: Dict[str, Any], default_value: Any): +def _choose_param_value(main_param_name: str, params: Dict[str, Any], default_value: Any) -> Dict[str, Any]: """Get a single parameter value, accounting for aliases. Parameters From 7816cb583fada3b204d994b1a838f01a9ab055df Mon Sep 17 00:00:00 2001 From: James Lamb Date: Mon, 25 Jan 2021 12:35:21 -0600 Subject: [PATCH 07/11] Apply suggestions from code review Co-authored-by: Nikita Titov --- python-package/lightgbm/basic.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python-package/lightgbm/basic.py b/python-package/lightgbm/basic.py index 411c26fa3b91..51a33c3b45e3 100644 --- a/python-package/lightgbm/basic.py +++ b/python-package/lightgbm/basic.py @@ -364,7 +364,7 @@ def _choose_param_value(main_param_name: str, params: Dict[str, Any], default_va params : dict Dictionary of LightGBM parameters. default_value : Any - Default value to use for the parameter, if none is found in ``params`` + Default value to use for the parameter, if none is found in ``params``. Returns ------- From d0ff5c350de77580c57ecdbcf074d9f8abac4aa1 Mon Sep 17 00:00:00 2001 From: James Lamb Date: Mon, 25 Jan 2021 12:35:46 -0600 Subject: [PATCH 08/11] Update python-package/lightgbm/dask.py Co-authored-by: Nikita Titov --- python-package/lightgbm/dask.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python-package/lightgbm/dask.py b/python-package/lightgbm/dask.py index ad6ade860e2f..d0829b918d91 100644 --- a/python-package/lightgbm/dask.py +++ b/python-package/lightgbm/dask.py @@ -225,7 +225,7 @@ def _train(client, data, label, params, model_factory, sample_weight=None, group 'Support for tree_learner %s in lightgbm.dask is experimental and may break in a future release. Use "data" for a stable, well-tested interface.' 
From 481a5b4faef0eae39ebf5085a24fba85b90a613f Mon Sep 17 00:00:00 2001
From: James Lamb
Date: Mon, 25 Jan 2021 12:36:43 -0600
Subject: [PATCH 09/11] just import deepcopy

---
 python-package/lightgbm/basic.py | 17 ++++++++---------
 1 file changed, 8 insertions(+), 9 deletions(-)

diff --git a/python-package/lightgbm/basic.py b/python-package/lightgbm/basic.py
index 51a33c3b45e3..17a08b46014a 100644
--- a/python-package/lightgbm/basic.py
+++ b/python-package/lightgbm/basic.py
@@ -1,6 +1,5 @@
 # coding: utf-8
 """Wrapper for C API of LightGBM."""
-import copy
 import ctypes
 import json
 import os
@@ -1094,7 +1093,7 @@ def __init__(self, data, label=None, reference=None,
         self.silent = silent
         self.feature_name = feature_name
         self.categorical_feature = categorical_feature
-        self.params = copy.deepcopy(params)
+        self.params = deepcopy(params)
         self.free_raw_data = free_raw_data
         self.used_indices = None
         self.need_slice = True
@@ -1552,13 +1551,13 @@ def save_binary(self, filename):
     def _update_params(self, params):
         if not params:
             return self
-        params = copy.deepcopy(params)
+        params = deepcopy(params)
 
         def update():
             if not self.params:
                 self.params = params
             else:
-                self.params_back_up = copy.deepcopy(self.params)
+                self.params_back_up = deepcopy(self.params)
                 self.params.update(params)
 
         if self.handle is None:
@@ -1578,7 +1577,7 @@ def update():
 
     def _reverse_update_params(self):
         if self.handle is None:
-            self.params = copy.deepcopy(self.params_back_up)
+            self.params = deepcopy(self.params_back_up)
             self.params_back_up = None
         return self
 
@@ -2172,7 +2171,7 @@ def __init__(self, params=None, train_set=None, model_file=None, model_str=None,
         self.__set_objective_to_none = False
         self.best_iteration = -1
         self.best_score = {}
-        params = {} if params is None else copy.deepcopy(params)
+        params = {} if params is None else deepcopy(params)
         # user can set verbose with params, it has higher priority
         if not any(verbose_alias in params for verbose_alias in _ConfigAliases.get("verbosity")) and silent:
             params["verbose"] = -1
@@ -3111,7 +3110,7 @@ def predict(self, data, start_iteration=0, num_iteration=None,
             Prediction result.
             Can be sparse or a list of sparse objects (each element represents predictions for one class) for feature contributions (when ``pred_contrib=True``).
         """
-        predictor = self._to_predictor(copy.deepcopy(kwargs))
+        predictor = self._to_predictor(deepcopy(kwargs))
         if num_iteration is None:
             if start_iteration <= 0:
                 num_iteration = self.best_iteration
@@ -3145,14 +3144,14 @@ def refit(self, data, label, decay_rate=0.9, **kwargs):
         """
         if self.__set_objective_to_none:
             raise LightGBMError('Cannot refit due to null objective function.')
-        predictor = self._to_predictor(copy.deepcopy(kwargs))
+        predictor = self._to_predictor(deepcopy(kwargs))
         leaf_preds = predictor.predict(data, -1, pred_leaf=True)
         nrow, ncol = leaf_preds.shape
         out_is_linear = ctypes.c_bool(False)
         _safe_call(_LIB.LGBM_BoosterGetLinear(
             self.handle,
             ctypes.byref(out_is_linear)))
-        new_params = copy.deepcopy(self.params)
+        new_params = deepcopy(self.params)
         new_params["linear_tree"] = out_is_linear.value
         train_set = Dataset(data, label, silent=True, params=new_params)
         new_params['refit_decay_rate'] = decay_rate
""" - predictor = self._to_predictor(copy.deepcopy(kwargs)) + predictor = self._to_predictor(deepcopy(kwargs)) if num_iteration is None: if start_iteration <= 0: num_iteration = self.best_iteration @@ -3145,14 +3144,14 @@ def refit(self, data, label, decay_rate=0.9, **kwargs): """ if self.__set_objective_to_none: raise LightGBMError('Cannot refit due to null objective function.') - predictor = self._to_predictor(copy.deepcopy(kwargs)) + predictor = self._to_predictor(deepcopy(kwargs)) leaf_preds = predictor.predict(data, -1, pred_leaf=True) nrow, ncol = leaf_preds.shape out_is_linear = ctypes.c_bool(False) _safe_call(_LIB.LGBM_BoosterGetLinear( self.handle, ctypes.byref(out_is_linear))) - new_params = copy.deepcopy(self.params) + new_params = deepcopy(self.params) new_params["linear_tree"] = out_is_linear.value train_set = Dataset(data, label, silent=True, params=new_params) new_params['refit_decay_rate'] = decay_rate From 7afd37bda9f5d3490af891adbf943e8a8edaf722 Mon Sep 17 00:00:00 2001 From: James Lamb Date: Mon, 25 Jan 2021 12:41:24 -0600 Subject: [PATCH 10/11] remove machines aliases --- python-package/lightgbm/basic.py | 63 +++++++++++++++++--------------- 1 file changed, 34 insertions(+), 29 deletions(-) diff --git a/python-package/lightgbm/basic.py b/python-package/lightgbm/basic.py index 17a08b46014a..d165ec840437 100644 --- a/python-package/lightgbm/basic.py +++ b/python-package/lightgbm/basic.py @@ -2180,35 +2180,40 @@ def __init__(self, params=None, train_set=None, model_file=None, model_str=None, if not isinstance(train_set, Dataset): raise TypeError('Training data should be Dataset instance, met {}' .format(type(train_set).__name__)) - # set network if necessary - for alias in _ConfigAliases.get("machines"): - if alias in params: - machines = params[alias] - if isinstance(machines, str): - num_machines_from_machine_list = len(machines.split(',')) - elif isinstance(machines, (list, set)): - num_machines_from_machine_list = len(machines) - machines = ','.join(machines) - else: - raise ValueError("Invalid machines in params.") - - params = _choose_param_value( - main_param_name="num_machines", - params=params, - default_value=num_machines_from_machine_list - ) - params = _choose_param_value( - main_param_name="local_listen_port", - params=params, - default_value=12400 - ) - self.set_network( - machines=machines, - local_listen_port=params["local_listen_port"], - listen_time_out=params.get("listen_time_out", 120), - num_machines=params["num_machines"] - ) - break + params = _choose_param_value( + main_param_name="machines", + params=params, + default_value=None + ) + # if "machines" is given, assume user wants to do distributed learning, and set up network + if params["machines"] is None: + params.pop("machines", None) + else: + machines = params["machines"] + if isinstance(machines, str): + num_machines_from_machine_list = len(machines.split(',')) + elif isinstance(machines, (list, set)): + num_machines_from_machine_list = len(machines) + machines = ','.join(machines) + else: + raise ValueError("Invalid machines in params.") + + params = _choose_param_value( + main_param_name="num_machines", + params=params, + default_value=num_machines_from_machine_list + ) + params = _choose_param_value( + main_param_name="local_listen_port", + params=params, + default_value=12400 + ) + self.set_network( + machines=machines, + local_listen_port=params["local_listen_port"], + listen_time_out=params.get("listen_time_out", 120), + num_machines=params["num_machines"] + ) # construct booster object 
From 1162b9e451a245a7bcb36e8cdf0324e8dcf1c2d2 Mon Sep 17 00:00:00 2001
From: James Lamb
Date: Tue, 26 Jan 2021 09:17:07 -0600
Subject: [PATCH 11/11] Apply suggestions from code review

Co-authored-by: Nikita Titov
---
 python-package/lightgbm/basic.py        | 2 +-
 python-package/lightgbm/dask.py         | 3 ++-
 tests/python_package_test/test_basic.py | 2 +-
 3 files changed, 4 insertions(+), 3 deletions(-)

diff --git a/python-package/lightgbm/basic.py b/python-package/lightgbm/basic.py
index d165ec840437..6b6c95be891a 100644
--- a/python-package/lightgbm/basic.py
+++ b/python-package/lightgbm/basic.py
@@ -2211,7 +2211,7 @@ def __init__(self, params=None, train_set=None, model_file=None, model_str=None,
             self.set_network(
                 machines=machines,
                 local_listen_port=params["local_listen_port"],
-                listen_time_out=params.get("listen_time_out", 120),
+                listen_time_out=params.get("time_out", 120),
                 num_machines=params["num_machines"]
             )
         # construct booster object
diff --git a/python-package/lightgbm/dask.py b/python-package/lightgbm/dask.py
index ffbfaf1b4772..72173c70c76d 100644
--- a/python-package/lightgbm/dask.py
+++ b/python-package/lightgbm/dask.py
@@ -221,7 +221,8 @@ def _train(client, data, label, params, model_factory, sample_weight=None, group
 
     if params['tree_learner'] not in {'data', 'data_parallel'}:
         _log_warning(
-            'Support for tree_learner %s in lightgbm.dask is experimental and may break in a future release. Use "data" for a stable, well-tested interface.' % params['tree_learner']
+            'Support for tree_learner %s in lightgbm.dask is experimental and may break in a future release. \n'
+            'Use "data" for a stable, well-tested interface.' % params['tree_learner']
         )
 
     # Some passed-in parameters can be removed:
diff --git a/tests/python_package_test/test_basic.py b/tests/python_package_test/test_basic.py
index 46cb8133514b..c48fc0041300 100644
--- a/tests/python_package_test/test_basic.py
+++ b/tests/python_package_test/test_basic.py
@@ -357,7 +357,7 @@ def test_choose_param_value():
         default_value=17
     )
     assert params["num_iterations"] == 81
-    assert "random_state" not in params
+    assert "num_trees" not in params
 
     # should use the default if main param and aliases are missing
     params = lgb.basic._choose_param_value(
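[Editor's note: end-to-end summary, not part of the patch series. Taken
together, the series means any documented alias now reaches the Dask
interface intact. A hypothetical usage sketch; assumes a running Dask
cluster, Dask collections dX/dy, and this branch of LightGBM installed.]

    import lightgbm.dask as dlgbm

    dask_regressor = dlgbm.DaskLGBMRegressor(
        local_port=12400,  # alias of local_listen_port, normalized by _choose_param_value()
        tree_type='data',  # alias of tree_learner
        n_estimators=10,
        num_leaves=10
    )
    # dask_regressor.fit(dX, dy) now honors both aliases instead of silently
    # falling back to the defaults, which is the behavior #3671 reported.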