diff --git a/python-package/lightgbm/basic.py b/python-package/lightgbm/basic.py
index b7052243b0ee..6b6c95be891a 100644
--- a/python-package/lightgbm/basic.py
+++ b/python-package/lightgbm/basic.py
@@ -1,14 +1,15 @@
 # coding: utf-8
 """Wrapper for C API of LightGBM."""
-import copy
 import ctypes
 import json
 import os
 import warnings
 from collections import OrderedDict
+from copy import deepcopy
 from functools import wraps
 from logging import Logger
 from tempfile import NamedTemporaryFile
+from typing import Any, Dict
 
 import numpy as np
 import scipy.sparse
@@ -352,6 +353,46 @@ def get(cls, *args):
         return ret
 
 
+def _choose_param_value(main_param_name: str, params: Dict[str, Any], default_value: Any) -> Dict[str, Any]:
+    """Get a single parameter value, accounting for aliases.
+
+    Parameters
+    ----------
+    main_param_name : str
+        Name of the main parameter to get a value for. One of the keys of ``_ConfigAliases``.
+    params : dict
+        Dictionary of LightGBM parameters.
+    default_value : Any
+        Default value to use for the parameter, if none is found in ``params``.
+
+    Returns
+    -------
+    params : dict
+        A ``params`` dict with exactly one value for ``main_param_name``, and all aliases of ``main_param_name`` removed.
+        If both ``main_param_name`` and one or more aliases for it are found, the value of ``main_param_name`` will be preferred.
+    """
+    # avoid side effects on passed-in parameters
+    params = deepcopy(params)
+
+    # find a value, and remove other aliases with .pop()
+    # prefer the value of 'main_param_name' if it exists, otherwise search the aliases
+    found_value = None
+    if main_param_name in params.keys():
+        found_value = params[main_param_name]
+
+    for param in _ConfigAliases.get(main_param_name):
+        val = params.pop(param, None)
+        if found_value is None and val is not None:
+            found_value = val
+
+    if found_value is not None:
+        params[main_param_name] = found_value
+    else:
+        params[main_param_name] = default_value
+
+    return params
+
+
 MAX_INT32 = (1 << 31) - 1
 
 """Macro definition of data type in C API of LightGBM"""
@@ -1052,7 +1093,7 @@ def __init__(self, data, label=None, reference=None,
         self.silent = silent
         self.feature_name = feature_name
         self.categorical_feature = categorical_feature
-        self.params = copy.deepcopy(params)
+        self.params = deepcopy(params)
         self.free_raw_data = free_raw_data
         self.used_indices = None
         self.need_slice = True
@@ -1510,13 +1551,13 @@ def save_binary(self, filename):
     def _update_params(self, params):
         if not params:
             return self
-        params = copy.deepcopy(params)
+        params = deepcopy(params)
 
         def update():
             if not self.params:
                 self.params = params
             else:
-                self.params_back_up = copy.deepcopy(self.params)
+                self.params_back_up = deepcopy(self.params)
                 self.params.update(params)
 
         if self.handle is None:
@@ -1536,7 +1577,7 @@ def update():
 
     def _reverse_update_params(self):
         if self.handle is None:
-            self.params = copy.deepcopy(self.params_back_up)
+            self.params = deepcopy(self.params_back_up)
             self.params_back_up = None
         return self
 
@@ -2130,7 +2171,7 @@ def __init__(self, params=None, train_set=None, model_file=None, model_str=None,
         self.__set_objective_to_none = False
         self.best_iteration = -1
         self.best_score = {}
-        params = {} if params is None else copy.deepcopy(params)
+        params = {} if params is None else deepcopy(params)
         # user can set verbose with params, it has higher priority
         if not any(verbose_alias in params for verbose_alias in _ConfigAliases.get("verbosity")) and silent:
             params["verbose"] = -1
@@ -2139,22 +2180,40 @@ def __init__(self, params=None, train_set=None, model_file=None, model_str=None,
             if not isinstance(train_set, Dataset):
                 raise TypeError('Training data should be Dataset instance, met {}'
                                 .format(type(train_set).__name__))
-            # set network if necessary
-            for alias in _ConfigAliases.get("machines"):
-                if alias in params:
-                    machines = params[alias]
-                    if isinstance(machines, str):
-                        num_machines = len(machines.split(','))
-                    elif isinstance(machines, (list, set)):
-                        num_machines = len(machines)
-                        machines = ','.join(machines)
-                    else:
-                        raise ValueError("Invalid machines in params.")
-                    self.set_network(machines,
-                                     local_listen_port=params.get("local_listen_port", 12400),
-                                     listen_time_out=params.get("listen_time_out", 120),
-                                     num_machines=params.setdefault("num_machines", num_machines))
-                    break
+            params = _choose_param_value(
+                main_param_name="machines",
+                params=params,
+                default_value=None
+            )
+            # if "machines" is given, assume user wants to do distributed learning, and set up network
+            if params["machines"] is None:
+                params.pop("machines", None)
+            else:
+                machines = params["machines"]
+                if isinstance(machines, str):
+                    num_machines_from_machine_list = len(machines.split(','))
+                elif isinstance(machines, (list, set)):
+                    num_machines_from_machine_list = len(machines)
+                    machines = ','.join(machines)
+                else:
+                    raise ValueError("Invalid machines in params.")
+
+                params = _choose_param_value(
+                    main_param_name="num_machines",
+                    params=params,
+                    default_value=num_machines_from_machine_list
+                )
+                params = _choose_param_value(
+                    main_param_name="local_listen_port",
+                    params=params,
+                    default_value=12400
+                )
+                self.set_network(
+                    machines=machines,
+                    local_listen_port=params["local_listen_port"],
+                    listen_time_out=params.get("time_out", 120),
+                    num_machines=params["num_machines"]
+                )
             # construct booster object
             train_set.construct()
             # copy the parameters from train_set
@@ -3056,7 +3115,7 @@ def predict(self, data, start_iteration=0, num_iteration=None,
             Prediction result.
             Can be sparse or a list of sparse objects (each element represents predictions for one class) for feature contributions (when ``pred_contrib=True``).
""" - predictor = self._to_predictor(copy.deepcopy(kwargs)) + predictor = self._to_predictor(deepcopy(kwargs)) if num_iteration is None: if start_iteration <= 0: num_iteration = self.best_iteration @@ -3090,14 +3149,14 @@ def refit(self, data, label, decay_rate=0.9, **kwargs): """ if self.__set_objective_to_none: raise LightGBMError('Cannot refit due to null objective function.') - predictor = self._to_predictor(copy.deepcopy(kwargs)) + predictor = self._to_predictor(deepcopy(kwargs)) leaf_preds = predictor.predict(data, -1, pred_leaf=True) nrow, ncol = leaf_preds.shape out_is_linear = ctypes.c_bool(False) _safe_call(_LIB.LGBM_BoosterGetLinear( self.handle, ctypes.byref(out_is_linear))) - new_params = copy.deepcopy(self.params) + new_params = deepcopy(self.params) new_params["linear_tree"] = out_is_linear.value train_set = Dataset(data, label, silent=True, params=new_params) new_params['refit_decay_rate'] = decay_rate diff --git a/python-package/lightgbm/dask.py b/python-package/lightgbm/dask.py index e5503577d636..72173c70c76d 100644 --- a/python-package/lightgbm/dask.py +++ b/python-package/lightgbm/dask.py @@ -21,7 +21,7 @@ from dask import delayed from dask.distributed import Client, default_client, get_worker, wait -from .basic import _ConfigAliases, _LIB, _log_warning, _safe_call, LightGBMError +from .basic import _choose_param_value, _ConfigAliases, _LIB, _log_warning, _safe_call, LightGBMError from .compat import DASK_INSTALLED, PANDAS_INSTALLED, SKLEARN_INSTALLED from .sklearn import LGBMClassifier, LGBMRegressor, LGBMRanker @@ -196,6 +196,44 @@ def _train(client, data, label, params, model_factory, sample_weight=None, group """ params = deepcopy(params) + params = _choose_param_value( + main_param_name="local_listen_port", + params=params, + default_value=12400 + ) + + params = _choose_param_value( + main_param_name="tree_learner", + params=params, + default_value="data" + ) + allowed_tree_learners = { + 'data', + 'data_parallel', + 'feature', + 'feature_parallel', + 'voting', + 'voting_parallel' + } + if params["tree_learner"] not in allowed_tree_learners: + _log_warning('Parameter tree_learner set to %s, which is not allowed. Using "data" as default' % tree_learner) + params['tree_learner'] = 'data' + + if params['tree_learner'] not in {'data', 'data_parallel'}: + _log_warning( + 'Support for tree_learner %s in lightgbm.dask is experimental and may break in a future release. \n' + 'Use "data" for a stable, well-tested interface.' % params['tree_learner'] + ) + + # Some passed-in parameters can be removed: + # * 'machines': constructed automatically from Dask worker list + # * 'machine_list_filename': not relevant for the Dask interface + # * 'num_machines': set automatically from Dask worker list + # * 'num_threads': overridden to match nthreads on each Dask process + for param_name in ['machines', 'machine_list_filename', 'num_machines', 'num_threads']: + for param_alias in _ConfigAliases.get(param_name): + params.pop(param_alias, None) + # Split arrays/dataframes into parts. 
     data_parts = _split_to_parts(data=data, is_matrix=True)
     label_parts = _split_to_parts(data=label, is_matrix=False)
@@ -230,65 +268,15 @@ def _train(client, data, label, params, model_factory, sample_weight=None, group
     master_worker = next(iter(worker_map))
     worker_ncores = client.ncores()
 
-    tree_learner = None
-    for tree_learner_param in _ConfigAliases.get('tree_learner'):
-        tree_learner = params.get(tree_learner_param)
-        if tree_learner is not None:
-            params['tree_learner'] = tree_learner
-            break
-
-    allowed_tree_learners = {
-        'data',
-        'data_parallel',
-        'feature',
-        'feature_parallel',
-        'voting',
-        'voting_parallel'
-    }
-    if tree_learner is None:
-        _log_warning('Parameter tree_learner not set. Using "data" as default')
-        params['tree_learner'] = 'data'
-    elif tree_learner.lower() not in allowed_tree_learners:
-        _log_warning('Parameter tree_learner set to %s, which is not allowed. Using "data" as default' % tree_learner)
-        params['tree_learner'] = 'data'
-
-    if params['tree_learner'] not in {'data', 'data_parallel'}:
-        _log_warning(
-            'Support for tree_learner %s in lightgbm.dask is experimental and may break in a future release. Use "data" for a stable, well-tested interface.' % params['tree_learner']
-        )
-
-    local_listen_port = 12400
-    for port_param in _ConfigAliases.get('local_listen_port'):
-        val = params.get(port_param)
-        if val is not None:
-            local_listen_port = val
-            break
-
     # find an open port on each worker. note that multiple workers can run
     # on the same machine, so this needs to ensure that each one gets its
     # own port
     worker_address_to_port = _find_ports_for_workers(
         client=client,
         worker_addresses=worker_map.keys(),
-        local_listen_port=local_listen_port
+        local_listen_port=params["local_listen_port"]
     )
 
-    # num_threads is set below, so remove it and all aliases of it from params
-    for num_thread_alias in _ConfigAliases.get('num_threads'):
-        params.pop(num_thread_alias, None)
-
-    # machines is constructed manually, so remove it and all aliases of it from params
-    for machine_alias in _ConfigAliases.get('machines'):
-        params.pop(machine_alias, None)
-
-    # machines is constructed manually, so remove machine_list_filename and all aliases of it from params
-    for machine_list_filename_alias in _ConfigAliases.get('machine_list_filename'):
-        params.pop(machine_list_filename_alias, None)
-
-    # machines is constructed manually, so remove num_machines and all aliases of it from params
-    for num_machine_alias in _ConfigAliases.get('num_machines'):
-        params.pop(num_machine_alias, None)
-
     # Tell each worker to train on the parts that it has locally
     futures_classifiers = [
         client.submit(
diff --git a/tests/python_package_test/test_basic.py b/tests/python_package_test/test_basic.py
index 7cc349ded449..c48fc0041300 100644
--- a/tests/python_package_test/test_basic.py
+++ b/tests/python_package_test/test_basic.py
@@ -329,3 +329,49 @@ def check_asserts(data):
     lgb_data.set_init_score(sequence)
     lgb_data.set_feature_name(feature_names)
     check_asserts(lgb_data)
+
+
+def test_choose_param_value():
+
+    original_params = {
+        "local_listen_port": 1234,
+        "port": 2222,
+        "metric": "auc",
+        "num_trees": 81
+    }
+
+    # should resolve duplicate aliases, and prefer the main parameter
+    params = lgb.basic._choose_param_value(
+        main_param_name="local_listen_port",
+        params=original_params,
+        default_value=5555
+    )
+    assert params["local_listen_port"] == 1234
+    assert "port" not in params
+
+    # should choose a value from an alias and set that value on main param
+    # if only an alias is used
+    params = lgb.basic._choose_param_value(
+        main_param_name="num_iterations",
+        params=params,
+        default_value=17
+    )
+    assert params["num_iterations"] == 81
+    assert "num_trees" not in params
+
+    # should use the default if main param and aliases are missing
+    params = lgb.basic._choose_param_value(
+        main_param_name="learning_rate",
+        params=params,
+        default_value=0.789
+    )
+    assert params["learning_rate"] == 0.789
+
+    # all changes should be made on copies and not modify the original
+    expected_params = {
+        "local_listen_port": 1234,
+        "port": 2222,
+        "metric": "auc",
+        "num_trees": 81
+    }
+    assert original_params == expected_params
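
As a minimal usage sketch of the helper added above (the parameter names and values here are illustrative only, chosen to mirror the test), alias resolution works like this: the main parameter name wins over any of its aliases, all aliases are dropped from the returned dict, the default is used only when neither is present, and the dict passed in is left untouched because the helper works on a deepcopy.

    import lightgbm as lgb

    params = {"num_trees": 100, "num_iterations": 150, "verbose": -1}
    resolved = lgb.basic._choose_param_value(
        main_param_name="num_iterations",
        params=params,
        default_value=17
    )
    # resolved == {"verbose": -1, "num_iterations": 150}: the main parameter is
    # preferred, the "num_trees" alias is removed, and the default (17) is ignored.
    # The original ``params`` dict is unchanged.

    fallback = lgb.basic._choose_param_value(
        main_param_name="learning_rate",
        params=params,
        default_value=0.1
    )
    # fallback["learning_rate"] == 0.1 because neither "learning_rate" nor any of
    # its aliases appears in ``params``, so the default is used.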