Skip to content

Commit

Permalink
Move IncreasedCloseTimeoutNanny to utils_test
Browse files Browse the repository at this point in the history
Move `IncreasedCloseTimeoutNanny` to `utils_test` and do not use it by
default in `LocalCUDACluster`. Tests should specify
`worker_class=IncreasedCloseTimeoutNanny` where appropriate.
  • Loading branch information
pentschev committed Oct 23, 2023
1 parent 723fa7e commit 21a57cd
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 11 deletions.
12 changes: 2 additions & 10 deletions dask_cuda/local_cuda_cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import os
import warnings
from functools import partial
from typing import Literal

import dask
from distributed import LocalCluster, Nanny, Worker
Expand All @@ -23,13 +22,6 @@
)


class IncreasedCloseTimeoutNanny(Nanny):
async def close( # type:ignore[override]
self, timeout: float = 10.0, reason: str = "nanny-close"
) -> Literal["OK"]:
return await super().close(timeout=timeout, reason=reason)


class LoggedWorker(Worker):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
Expand All @@ -39,7 +31,7 @@ async def start(self):
self.data.set_address(self.address)


class LoggedNanny(IncreasedCloseTimeoutNanny):
class LoggedNanny(Nanny):
def __init__(self, *args, **kwargs):
super().__init__(*args, worker_class=LoggedWorker, **kwargs)

Expand Down Expand Up @@ -341,7 +333,7 @@ def __init__(
)

worker_class = partial(
LoggedNanny if log_spilling is True else IncreasedCloseTimeoutNanny,
LoggedNanny if log_spilling is True else Nanny,
worker_class=worker_class,
)

Expand Down
24 changes: 23 additions & 1 deletion dask_cuda/utils_test.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from typing import Literal

import distributed
from distributed import Worker
from distributed import Nanny, Worker


class MockWorker(Worker):
Expand All @@ -21,3 +23,23 @@ def __del__(self):
@staticmethod
def device_get_count():
return 0


class IncreasedCloseTimeoutNanny(Nanny):
"""Increase `Nanny`'s close timeout.
The internal close timeout mechanism of `Nanny` recomputes the time left to kill
the `Worker` process based on elapsed time of the close task, which may leave
very little time for the subprocess to shutdown cleanly, which may cause tests
to fail when the system is under higher load. This class increases the default
close timeout of 5.0 seconds that `Nanny` sets by default, which can be overriden
via Distributed's public API.
This class can be used with the `worker_class` argument of `LocalCluster` or
`LocalCUDACluster` to provide a much higher default of 30.0 seconds.
"""

async def close( # type:ignore[override]
self, timeout: float = 30.0, reason: str = "nanny-close"
) -> Literal["OK"]:
return await super().close(timeout=timeout, reason=reason)

0 comments on commit 21a57cd

Please sign in to comment.