diff --git a/python-package/lightgbm/dask.py b/python-package/lightgbm/dask.py
index 8cba6de46fb3..b3571bbd786f 100644
--- a/python-package/lightgbm/dask.py
+++ b/python-package/lightgbm/dask.py
@@ -170,6 +170,44 @@ def _machines_to_worker_map(machines: str, worker_addresses: List[str]) -> Dict[
     return out
 
 
+def _possibly_fix_worker_map_duplicates(worker_map: Dict[str, int], client: Client) -> Dict[str, int]:
+    """Fix any duplicate IP-port pairs in a ``worker_map``."""
+    worker_map = deepcopy(worker_map)
+    workers_that_need_new_ports = []
+    host_to_port = defaultdict(set)
+    for worker, port in worker_map.items():
+        host = urlparse(worker).hostname
+        if port in host_to_port[host]:
+            workers_that_need_new_ports.append(worker)
+        else:
+            host_to_port[host].add(port)
+
+    # if any duplicates were found, search for new ports one by one
+    for worker in workers_that_need_new_ports:
+        _log_info(f"Searching for a LightGBM training port for worker '{worker}'")
+        host = urlparse(worker).hostname
+        retries_remaining = 100
+        while retries_remaining > 0:
+            retries_remaining -= 1
+            new_port = client.submit(
+                _find_random_open_port,
+                workers=[worker],
+                allow_other_workers=False,
+                pure=False
+            ).result()
+            if new_port not in host_to_port[host]:
+                worker_map[worker] = new_port
+                host_to_port[host].add(new_port)
+                break
+
+        if retries_remaining == 0:
+            raise LightGBMError(
+                "Failed to find an open port. Try re-running training or explicitly setting 'machines' or 'local_listen_port'."
+            )
+
+    return worker_map
+
+
 def _train(
     client: Client,
     data: _DaskMatrixLike,
@@ -367,10 +405,19 @@ def _train(
         }
     else:
         _log_info("Finding random open ports for workers")
+        # this approach with client.run() is faster than searching for ports
+        # serially, but can produce duplicates sometimes. Try the fast approach one
+        # time, then pass it through a function that will use a slower but more reliable
+        # approach if duplicates are found.
         worker_address_to_port = client.run(
             _find_random_open_port,
             workers=list(worker_addresses)
         )
+        worker_address_to_port = _possibly_fix_worker_map_duplicates(
+            worker_map=worker_address_to_port,
+            client=client
+        )
+
         machines = ','.join([
             '%s:%d' % (urlparse(worker_address).hostname, port)
             for worker_address, port
diff --git a/tests/python_package_test/test_dask.py b/tests/python_package_test/test_dask.py
index 4597a4b60e45..1ed7284ce305 100644
--- a/tests/python_package_test/test_dask.py
+++ b/tests/python_package_test/test_dask.py
@@ -392,6 +392,37 @@ def test_find_random_open_port(client):
     client.close(timeout=CLIENT_CLOSE_TIMEOUT)
 
 
+def test_possibly_fix_worker_map(capsys, client):
+    client.wait_for_workers(2)
+    worker_addresses = list(client.scheduler_info()["workers"].keys())
+
+    retry_msg = 'Searching for a LightGBM training port for worker'
+
+    # should handle worker maps without any duplicates
+    map_without_duplicates = {
+        worker_address: 12400 + i
+        for i, worker_address in enumerate(worker_addresses)
+    }
+    patched_map = lgb.dask._possibly_fix_worker_map_duplicates(
+        client=client,
+        worker_map=map_without_duplicates
+    )
+    assert patched_map == map_without_duplicates
+    assert retry_msg not in capsys.readouterr().out
+
+    # should handle worker maps with duplicates
+    map_with_duplicates = {
+        worker_address: 12400
+        for i, worker_address in enumerate(worker_addresses)
+    }
+    patched_map = lgb.dask._possibly_fix_worker_map_duplicates(
+        client=client,
+        worker_map=map_with_duplicates
+    )
+    assert retry_msg in capsys.readouterr().out
+    assert len(set(patched_map.values())) == len(worker_addresses)
+
+
 def test_training_does_not_fail_on_port_conflicts(client):
     _, _, _, _, dX, dy, dw, _ = _create_data('binary-classification', output='array')
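
The duplicate-port fix added above can be illustrated outside Dask. The sketch below is not part of the patch: fake_find_open_port and fix_duplicates are hypothetical stand-ins for _find_random_open_port (which the patch runs on the affected worker via client.submit) and for _possibly_fix_worker_map_duplicates itself; it only shows the keep-first-port-per-host / re-draw-on-collision idea.

# standalone sketch of the dedup logic, assuming a stubbed port search
import random
from collections import defaultdict
from urllib.parse import urlparse


def fake_find_open_port() -> int:
    # hypothetical stand-in for lightgbm.dask._find_random_open_port on a worker
    return random.randint(12400, 12500)


def fix_duplicates(worker_map):
    worker_map = dict(worker_map)
    host_to_ports = defaultdict(set)
    needs_new_port = []
    # keep the first port seen for each host, remember workers that collide
    for worker, port in worker_map.items():
        host = urlparse(worker).hostname
        if port in host_to_ports[host]:
            needs_new_port.append(worker)
        else:
            host_to_ports[host].add(port)
    # re-draw ports for the colliding workers until an unused one is found
    for worker in needs_new_port:
        host = urlparse(worker).hostname
        for _ in range(100):
            new_port = fake_find_open_port()
            if new_port not in host_to_ports[host]:
                worker_map[worker] = new_port
                host_to_ports[host].add(new_port)
                break
    return worker_map


print(fix_duplicates({
    "tcp://127.0.0.1:40001": 12400,
    "tcp://127.0.0.1:40002": 12400,  # duplicate port on the same host
}))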