[python-package] respect parameter aliases for network params #3813

Merged · 19 commits · Jan 26, 2021
109 changes: 84 additions & 25 deletions python-package/lightgbm/basic.py
@@ -1,14 +1,15 @@
# coding: utf-8
"""Wrapper for C API of LightGBM."""
import copy
import ctypes
import json
import os
import warnings
from collections import OrderedDict
from copy import deepcopy
from functools import wraps
from logging import Logger
from tempfile import NamedTemporaryFile
from typing import Any, Dict

import numpy as np
import scipy.sparse
@@ -352,6 +353,46 @@ def get(cls, *args):
return ret


def _choose_param_value(main_param_name: str, params: Dict[str, Any], default_value: Any) -> Dict[str, Any]:
Collaborator: TBH, I don't like the _choose_param_value name very much, but I don't have alternatives to suggest. Maybe you have?

Collaborator (Author): I like this name very much and think it's descriptive of what is happening. Given a dictionary of parameters, it chooses a single value to use for a single parameter.
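
For illustration only (not part of the diff), a minimal usage sketch of how the helper is meant to behave, assuming "port" is registered as an alias of "local_listen_port" in _ConfigAliases (as the new test below exercises):

from lightgbm.basic import _choose_param_value

# only the alias is present: the helper deep-copies the dict, pops the
# alias key, and stores its value under the main parameter name
params = _choose_param_value(
    main_param_name="local_listen_port",
    params={"port": 13000, "metric": "auc"},
    default_value=12400
)
assert params == {"metric": "auc", "local_listen_port": 13000}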

"""Get a single parameter value, accounting for aliases.

Parameters
----------
main_param_name : str
Name of the main parameter to get a value for. One of the keys of ``_ConfigAliases``.
params : dict
Dictionary of LightGBM parameters.
default_value : Any
Default value to use for the parameter, if none is found in ``params``.

Returns
-------
params : dict
A ``params`` dict with exactly one value for ``main_param_name``, and all aliases of ``main_param_name`` removed.
If both ``main_param_name`` and one or more aliases for it are found, the value of ``main_param_name`` will be preferred.
"""
# avoid side effects on passed-in parameters
params = deepcopy(params)

# find a value, and remove other aliases with .pop()
# prefer the value of 'main_param_name' if it exists, otherwise search the aliases
found_value = None
if main_param_name in params.keys():
found_value = params[main_param_name]

for param in _ConfigAliases.get(main_param_name):
val = params.pop(param, None)
if found_value is None and val is not None:
found_value = val

if found_value is not None:
params[main_param_name] = found_value
else:
params[main_param_name] = default_value

return params


MAX_INT32 = (1 << 31) - 1

"""Macro definition of data type in C API of LightGBM"""
@@ -1052,7 +1093,7 @@ def __init__(self, data, label=None, reference=None,
self.silent = silent
self.feature_name = feature_name
self.categorical_feature = categorical_feature
self.params = copy.deepcopy(params)
self.params = deepcopy(params)
self.free_raw_data = free_raw_data
self.used_indices = None
self.need_slice = True
@@ -1510,13 +1551,13 @@ def save_binary(self, filename):
def _update_params(self, params):
if not params:
return self
params = copy.deepcopy(params)
params = deepcopy(params)

def update():
if not self.params:
self.params = params
else:
self.params_back_up = copy.deepcopy(self.params)
self.params_back_up = deepcopy(self.params)
self.params.update(params)

if self.handle is None:
@@ -1536,7 +1577,7 @@ def update():

def _reverse_update_params(self):
if self.handle is None:
self.params = copy.deepcopy(self.params_back_up)
self.params = deepcopy(self.params_back_up)
self.params_back_up = None
return self

@@ -2130,7 +2171,7 @@ def __init__(self, params=None, train_set=None, model_file=None, model_str=None,
self.__set_objective_to_none = False
self.best_iteration = -1
self.best_score = {}
params = {} if params is None else copy.deepcopy(params)
params = {} if params is None else deepcopy(params)
# user can set verbose with params, it has higher priority
if not any(verbose_alias in params for verbose_alias in _ConfigAliases.get("verbosity")) and silent:
params["verbose"] = -1
@@ -2139,22 +2180,40 @@ def __init__(self, params=None, train_set=None, model_file=None, model_str=None,
if not isinstance(train_set, Dataset):
raise TypeError('Training data should be Dataset instance, met {}'
.format(type(train_set).__name__))
# set network if necessary
for alias in _ConfigAliases.get("machines"):
if alias in params:
machines = params[alias]
if isinstance(machines, str):
num_machines = len(machines.split(','))
elif isinstance(machines, (list, set)):
num_machines = len(machines)
machines = ','.join(machines)
else:
raise ValueError("Invalid machines in params.")
self.set_network(machines,
local_listen_port=params.get("local_listen_port", 12400),
listen_time_out=params.get("listen_time_out", 120),
num_machines=params.setdefault("num_machines", num_machines))
break
params = _choose_param_value(
main_param_name="machines",
params=params,
default_value=None
)
# if "machines" is given, assume user wants to do distributed learning, and set up network
if params["machines"] is None:
params.pop("machines", None)
else:
machines = params["machines"]
if isinstance(machines, str):
num_machines_from_machine_list = len(machines.split(','))
elif isinstance(machines, (list, set)):
num_machines_from_machine_list = len(machines)
machines = ','.join(machines)
else:
raise ValueError("Invalid machines in params.")

params = _choose_param_value(
main_param_name="num_machines",
params=params,
default_value=num_machines_from_machine_list
)
params = _choose_param_value(
main_param_name="local_listen_port",
params=params,
default_value=12400
)
self.set_network(
machines=machines,
local_listen_port=params["local_listen_port"],
listen_time_out=params.get("time_out", 120),
num_machines=params["num_machines"]
)
# construct booster object
train_set.construct()
# copy the parameters from train_set
@@ -3056,7 +3115,7 @@ def predict(self, data, start_iteration=0, num_iteration=None,
Prediction result.
Can be sparse or a list of sparse objects (each element represents predictions for one class) for feature contributions (when ``pred_contrib=True``).
"""
predictor = self._to_predictor(copy.deepcopy(kwargs))
predictor = self._to_predictor(deepcopy(kwargs))
if num_iteration is None:
if start_iteration <= 0:
num_iteration = self.best_iteration
@@ -3090,14 +3149,14 @@ def refit(self, data, label, decay_rate=0.9, **kwargs):
"""
if self.__set_objective_to_none:
raise LightGBMError('Cannot refit due to null objective function.')
predictor = self._to_predictor(copy.deepcopy(kwargs))
predictor = self._to_predictor(deepcopy(kwargs))
leaf_preds = predictor.predict(data, -1, pred_leaf=True)
nrow, ncol = leaf_preds.shape
out_is_linear = ctypes.c_bool(False)
_safe_call(_LIB.LGBM_BoosterGetLinear(
self.handle,
ctypes.byref(out_is_linear)))
new_params = copy.deepcopy(self.params)
new_params = deepcopy(self.params)
new_params["linear_tree"] = out_is_linear.value
train_set = Dataset(data, label, silent=True, params=new_params)
new_params['refit_decay_rate'] = decay_rate
92 changes: 40 additions & 52 deletions python-package/lightgbm/dask.py
@@ -21,7 +21,7 @@
from dask import delayed
from dask.distributed import Client, default_client, get_worker, wait

from .basic import _ConfigAliases, _LIB, _log_warning, _safe_call, LightGBMError
from .basic import _choose_param_value, _ConfigAliases, _LIB, _log_warning, _safe_call, LightGBMError
from .compat import DASK_INSTALLED, PANDAS_INSTALLED, SKLEARN_INSTALLED
from .sklearn import LGBMClassifier, LGBMRegressor, LGBMRanker

@@ -196,6 +196,44 @@ def _train(client, data, label, params, model_factory, sample_weight=None, group
"""
params = deepcopy(params)

params = _choose_param_value(
main_param_name="local_listen_port",
params=params,
default_value=12400
)

params = _choose_param_value(
main_param_name="tree_learner",
params=params,
default_value="data"
)
allowed_tree_learners = {
'data',
'data_parallel',
'feature',
'feature_parallel',
'voting',
'voting_parallel'
}
if params["tree_learner"] not in allowed_tree_learners:
_log_warning('Parameter tree_learner set to %s, which is not allowed. Using "data" as default' % params['tree_learner'])
params['tree_learner'] = 'data'

if params['tree_learner'] not in {'data', 'data_parallel'}:
_log_warning(
'Support for tree_learner %s in lightgbm.dask is experimental and may break in a future release. \n'
'Use "data" for a stable, well-tested interface.' % params['tree_learner']
)

# Some passed-in parameters can be removed:
# * 'machines': constructed automatically from Dask worker list
# * 'machine_list_filename': not relevant for the Dask interface
# * 'num_machines': set automatically from Dask worker list
# * 'num_threads': overridden to match nthreads on each Dask process
for param_name in ['machines', 'machine_list_filename', 'num_machines', 'num_threads']:
for param_alias in _ConfigAliases.get(param_name):
params.pop(param_alias, None)

# Split arrays/dataframes into parts. Arrange parts into dicts to enforce co-locality
data_parts = _split_to_parts(data=data, is_matrix=True)
label_parts = _split_to_parts(data=label, is_matrix=False)
@@ -230,65 +268,15 @@ def _train(client, data, label, params, model_factory, sample_weight=None, group
master_worker = next(iter(worker_map))
worker_ncores = client.ncores()

tree_learner = None
for tree_learner_param in _ConfigAliases.get('tree_learner'):
tree_learner = params.get(tree_learner_param)
if tree_learner is not None:
params['tree_learner'] = tree_learner
break

allowed_tree_learners = {
'data',
'data_parallel',
'feature',
'feature_parallel',
'voting',
'voting_parallel'
}
if tree_learner is None:
_log_warning('Parameter tree_learner not set. Using "data" as default')
params['tree_learner'] = 'data'
elif tree_learner.lower() not in allowed_tree_learners:
_log_warning('Parameter tree_learner set to %s, which is not allowed. Using "data" as default' % tree_learner)
params['tree_learner'] = 'data'

if params['tree_learner'] not in {'data', 'data_parallel'}:
_log_warning(
'Support for tree_learner %s in lightgbm.dask is experimental and may break in a future release. Use "data" for a stable, well-tested interface.' % params['tree_learner']
)

local_listen_port = 12400
for port_param in _ConfigAliases.get('local_listen_port'):
val = params.get(port_param)
if val is not None:
local_listen_port = val
break

# find an open port on each worker. note that multiple workers can run
# on the same machine, so this needs to ensure that each one gets its
# own port
worker_address_to_port = _find_ports_for_workers(
client=client,
worker_addresses=worker_map.keys(),
local_listen_port=local_listen_port
local_listen_port=params["local_listen_port"]
)

# num_threads is set below, so remove it and all aliases of it from params
for num_thread_alias in _ConfigAliases.get('num_threads'):
params.pop(num_thread_alias, None)

# machines is constructed manually, so remove it and all aliases of it from params
for machine_alias in _ConfigAliases.get('machines'):
params.pop(machine_alias, None)

# machines is constructed manually, so remove machine_list_filename and all aliases of it from params
for machine_list_filename_alias in _ConfigAliases.get('machine_list_filename'):
params.pop(machine_list_filename_alias, None)

# machines is constructed manually, so remove num_machines and all aliases of it from params
for num_machine_alias in _ConfigAliases.get('num_machines'):
params.pop(num_machine_alias, None)

# Tell each worker to train on the parts that it has locally
futures_classifiers = [
client.submit(
46 changes: 46 additions & 0 deletions tests/python_package_test/test_basic.py
@@ -329,3 +329,49 @@ def check_asserts(data):
lgb_data.set_init_score(sequence)
lgb_data.set_feature_name(feature_names)
check_asserts(lgb_data)


def test_choose_param_value():

original_params = {
"local_listen_port": 1234,
"port": 2222,
"metric": "auc",
"num_trees": 81
}

# should resolve duplicate aliases, and prefer the main parameter
params = lgb.basic._choose_param_value(
main_param_name="local_listen_port",
params=original_params,
default_value=5555
)
assert params["local_listen_port"] == 1234
assert "port" not in params

# should choose a value from an alias and set that value on main param
# if only an alias is used
params = lgb.basic._choose_param_value(
main_param_name="num_iterations",
params=params,
default_value=17
)
assert params["num_iterations"] == 81
assert "num_trees" not in params

# should use the default if main param and aliases are missing
params = lgb.basic._choose_param_value(
main_param_name="learning_rate",
params=params,
default_value=0.789
)
assert params["learning_rate"] == 0.789

# all changes should be made on copies and not modify the original
expected_params = {
"local_listen_port": 1234,
"port": 2222,
"metric": "auc",
"num_trees": 81
}
Comment on lines +371 to +376
Collaborator: I think expected_params can be gotten by copying original_params at the very beginning of the test, to avoid re-typing every value here.

Suggested change
- expected_params = {
-     "local_listen_port": 1234,
-     "port": 2222,
-     "metric": "auc",
-     "num_trees": 81
- }
+ expected_params = copy(original_params)

Collaborator (Author): copy() performs a shallow clone, so some modifications to original_params can still impact the copy. https://docs.python.org/3/library/copy.html#copy.copy

I think it's preferable for this test to re-type the literal values, so that the test doesn't give a false positive because of a subtle mistake where we assumed something was a copy but it was actually a reference.
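
For context (not part of the diff), a small standalone illustration of the shallow-versus-deep-copy distinction described above:

from copy import copy, deepcopy

original = {"metric": ["auc", "l2"], "num_trees": 81}

shallow = copy(original)
shallow["metric"].append("l1")   # the nested list is shared with `original`
print(original["metric"])        # ['auc', 'l2', 'l1'] -- the original was mutated too

deep = deepcopy(original)
deep["metric"].append("rmse")
print(original["metric"])        # ['auc', 'l2', 'l1'] -- unaffected by the deep copy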

assert original_params == expected_params