diff --git a/python-package/lightgbm/dask.py b/python-package/lightgbm/dask.py index 4acbf10702d7..cfa5f79d9bbd 100644 --- a/python-package/lightgbm/dask.py +++ b/python-package/lightgbm/dask.py @@ -244,6 +244,7 @@ def _train(client, data, label, params, model_factory, sample_weight=None, group for tree_learner_param in _ConfigAliases.get('tree_learner'): tree_learner = params.get(tree_learner_param) if tree_learner is not None: + params['tree_learner'] = tree_learner break allowed_tree_learners = { @@ -261,6 +262,11 @@ def _train(client, data, label, params, model_factory, sample_weight=None, group _log_warning('Parameter tree_learner set to %s, which is not allowed. Using "data" as default' % tree_learner) params['tree_learner'] = 'data' + if params['tree_learner'] not in {'data', 'data_parallel'}: + _log_warning( + 'Support for tree_learner %s in lightgbm.dask is experimental and may break in a future release. Use "data" for a stable, well-tested interface.' % params['tree_learner'] + ) + local_listen_port = 12400 for port_param in _ConfigAliases.get('local_listen_port'): val = params.get(port_param)