From c068e21a97ba1bbf0a5474af7df721b94186c2af Mon Sep 17 00:00:00 2001 From: Peter Andreas Entschev Date: Thu, 12 Oct 2023 05:54:34 -0700 Subject: [PATCH] Increase close timeout of `Nanny` in `LocalCUDACluster` Tests in CI have been failing more often, but those errors can't be reproduced locally. This is possibly related to `Nanny`'s internal mechanism to establish timeouts to kill processes, perhaps due to higher load on the servers, tasks take longer and killing processes takes into account the overall time taken to establish a timeout, which is then drastically reduced leaving little time to actually shutdown processes. It is also not possible to programatically set a different timeout given existing Distributed's API, which currently calls `close()` without arguments in `SpecCluster._correct_state_internal()`. Given the limitations described above, a new class is added by this change with the sole purpose of rewriting the timeout for `Nanny.close()` method with an increased value, and then use the new class when launching `LocalCUDACluster` via the `worker_class` argument. --- dask_cuda/local_cuda_cluster.py | 22 ++++++++++++++-------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/dask_cuda/local_cuda_cluster.py b/dask_cuda/local_cuda_cluster.py index 324484331..ef15dcce3 100644 --- a/dask_cuda/local_cuda_cluster.py +++ b/dask_cuda/local_cuda_cluster.py @@ -2,6 +2,8 @@ import logging import os import warnings +from functools import partial +from typing import Literal import dask from distributed import LocalCluster, Nanny, Worker @@ -23,6 +25,13 @@ ) +class IncreasedCloseTimeoutNanny(Nanny): + async def close( # type:ignore[override] + self, timeout: float = 10.0, reason: str = "nanny-close" + ) -> Literal["OK"]: + return await super().close(timeout=timeout, reason=reason) + + class LoggedWorker(Worker): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -32,7 +41,7 @@ async def start(self): self.data.set_address(self.address) -class LoggedNanny(Nanny): +class LoggedNanny(IncreasedCloseTimeoutNanny): def __init__(self, *args, **kwargs): super().__init__(*args, worker_class=LoggedWorker, **kwargs) @@ -333,13 +342,10 @@ def __init__( enable_rdmacm=enable_rdmacm, ) - if worker_class is not None: - from functools import partial - - worker_class = partial( - LoggedNanny if log_spilling is True else Nanny, - worker_class=worker_class, - ) + worker_class = partial( + LoggedNanny if log_spilling is True else IncreasedCloseTimeoutNanny, + worker_class=worker_class, + ) self.pre_import = pre_import