def wait(
    cluster,
    min_workers=None,
    pbar=True,
    pbar_kwargs=None,
    timeout=None,
    poll_interval=0.5,
):
    """
    Block execution while a cluster scales to the requested number of workers

    Note that this function does not currently work on dask gateway clusters

    Parameters
    ----------
    cluster : dask_kubernetes.KubeCluster
        Scalable dask cluster object with ``requested`` and ``scheduler``
        attributes. These attributes are used to determine how many
        workers have been requested and how many are currently available.
    min_workers : int, optional
        Number of workers to wait for before returning. Default is the
        total number of requested workers. This argument can be used to
        set a minimum acceptable number, even if the total requested is
        higher. This may be important for adaptive clusters, where the
        number requested may change during the function's execution.
    pbar : bool, optional
        If true, displays a tqdm progress bar to track the workers'
        spin-up. Note that while this function is running, any dask
        interactive widgets will not update, so the worker count on
        the widget may be inaccurate. The progress bar displayed by
        this function will reflect the actual worker count.
    pbar_kwargs : dict, optional
        Optional additional keyword arguments to pass to
        :py:func:`tqdm.auto.tqdm` (default ``{}``).
    timeout : float, optional
        Maximum number of seconds to wait, measured from the moment the
        function is called. Default (``None``) waits indefinitely.
    poll_interval : float, optional
        Seconds to sleep between checks of the worker count. The same
        delay is applied once before reading the requested worker count
        when ``min_workers`` is not given. Default ``0.5``.

    Raises
    ------
    TimeoutError
        If ``timeout`` is given and the cluster has not reached the target
        number of workers before it elapses.
    """

    deadline = None if timeout is None else time.monotonic() + timeout

    if min_workers is None:
        # don't race the cluster: give it a moment to register the scale
        # request before reading how many workers were asked for
        time.sleep(poll_interval)
        min_workers = len(cluster.requested)

    if pbar_kwargs is None:
        pbar_kwargs = {}

    bar = None
    if pbar:
        # imported lazily so tqdm is only required when a bar is requested
        from tqdm.auto import tqdm

        bar = tqdm(total=min_workers, **pbar_kwargs)

    try:
        while True:
            num_workers = len(cluster.scheduler.workers)

            # check to see if the request has decreased, e.g. in an
            # adaptive cluster
            min_workers = min(min_workers, len(cluster.requested))

            if bar is not None:
                # keep the bar's total in sync with a shrinking target so
                # it does not finish showing a stale denominator
                if bar.total != min_workers:
                    bar.total = min_workers
                bar.n = num_workers
                bar.refresh()

            if num_workers >= min_workers:
                break

            if deadline is not None and time.monotonic() >= deadline:
                raise TimeoutError(
                    "cluster reached {} of {} workers within {} seconds".format(
                        num_workers, min_workers, timeout
                    )
                )

            time.sleep(poll_interval)
    finally:
        # close the bar on success, timeout, and interruption alike
        if bar is not None:
            bar.close()