diff --git a/python-package/lightgbm/basic.py b/python-package/lightgbm/basic.py index 5b0a73f4ed20..b7052243b0ee 100644 --- a/python-package/lightgbm/basic.py +++ b/python-package/lightgbm/basic.py @@ -298,6 +298,10 @@ class _ConfigAliases: "local_listen_port": {"local_listen_port", "local_port", "port"}, + "machine_list_filename": {"machine_list_filename", + "machine_list_file", + "machine_list", + "mlist"}, "machines": {"machines", "workers", "nodes"}, @@ -315,6 +319,8 @@ class _ConfigAliases: "num_rounds", "num_boost_round", "n_estimators"}, + "num_machines": {"num_machines", + "num_machine"}, "num_threads": {"num_threads", "num_thread", "nthread", diff --git a/python-package/lightgbm/dask.py b/python-package/lightgbm/dask.py index 14d349db961f..3fbb6183d9ee 100644 --- a/python-package/lightgbm/dask.py +++ b/python-package/lightgbm/dask.py @@ -230,7 +230,7 @@ def _train(client, data, label, params, model_factory, sample_weight=None, group return part # trigger error locally # Find locations of all parts and map them to particular Dask workers - key_to_part_dict = dict([(part.key, part) for part in parts]) + key_to_part_dict = {part.key: part for part in parts} who_has = client.who_has(parts) worker_map = defaultdict(list) for key, workers in who_has.items(): @@ -280,6 +280,18 @@ def _train(client, data, label, params, model_factory, sample_weight=None, group for num_thread_alias in _ConfigAliases.get('num_threads'): params.pop(num_thread_alias, None) + # machines is constructed manually, so remove it and all aliases of it from params + for machine_alias in _ConfigAliases.get('machines'): + params.pop(machine_alias, None) + + # machines is constructed manually, so remove machine_list_filename and all aliases of it from params + for machine_list_filename_alias in _ConfigAliases.get('machine_list_filename'): + params.pop(machine_list_filename_alias, None) + + # machines is constructed manually, so remove num_machines and all aliases of it from params + for num_machine_alias in _ConfigAliases.get('num_machines'): + params.pop(num_machine_alias, None) + # Tell each worker to train on the parts that it has locally futures_classifiers = [ client.submit(