diff --git a/dask_cuda/local_cuda_cluster.py b/dask_cuda/local_cuda_cluster.py index a82d59055..e84dce0c7 100644 --- a/dask_cuda/local_cuda_cluster.py +++ b/dask_cuda/local_cuda_cluster.py @@ -3,7 +3,6 @@ import os import warnings from functools import partial -from typing import Literal import dask from distributed import LocalCluster, Nanny, Worker @@ -23,13 +22,6 @@ ) -class IncreasedCloseTimeoutNanny(Nanny): - async def close( # type:ignore[override] - self, timeout: float = 10.0, reason: str = "nanny-close" - ) -> Literal["OK"]: - return await super().close(timeout=timeout, reason=reason) - - class LoggedWorker(Worker): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -39,7 +31,7 @@ async def start(self): self.data.set_address(self.address) -class LoggedNanny(IncreasedCloseTimeoutNanny): +class LoggedNanny(Nanny): def __init__(self, *args, **kwargs): super().__init__(*args, worker_class=LoggedWorker, **kwargs) @@ -341,7 +333,7 @@ def __init__( ) worker_class = partial( - LoggedNanny if log_spilling is True else IncreasedCloseTimeoutNanny, + LoggedNanny if log_spilling is True else Nanny, worker_class=worker_class, ) diff --git a/dask_cuda/utils_test.py b/dask_cuda/utils_test.py index 42582d7f8..aba77ee79 100644 --- a/dask_cuda/utils_test.py +++ b/dask_cuda/utils_test.py @@ -1,5 +1,7 @@ +from typing import Literal + import distributed -from distributed import Worker +from distributed import Nanny, Worker class MockWorker(Worker): @@ -21,3 +23,23 @@ def __del__(self): @staticmethod def device_get_count(): return 0 + + +class IncreasedCloseTimeoutNanny(Nanny): + """Increase `Nanny`'s close timeout. + + The internal close timeout mechanism of `Nanny` recomputes the time left to kill + the `Worker` process based on elapsed time of the close task, which may leave + very little time for the subprocess to shutdown cleanly, which may cause tests + to fail when the system is under higher load. 
This class increases the default + close timeout of 5.0 seconds that `Nanny` sets by default, which can be overridden + via Distributed's public API. + + This class can be used with the `worker_class` argument of `LocalCluster` or + `LocalCUDACluster` to provide a much higher default of 30.0 seconds. + """ + + async def close( # type:ignore[override] + self, timeout: float = 30.0, reason: str = "nanny-close" + ) -> Literal["OK"]: + return await super().close(timeout=timeout, reason=reason)