From e5711c6e24865d9d62eebbeb0776966dc344c523 Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Wed, 5 Feb 2025 13:45:59 +0100 Subject: [PATCH] Cache occupancy in `WorkStealing.balance()` (#9005) --- distributed/stealing.py | 46 ++++++++++++++++++++++----------- distributed/tests/test_steal.py | 42 ++++++++++++++++-------------- 2 files changed, 53 insertions(+), 35 deletions(-) diff --git a/distributed/stealing.py b/distributed/stealing.py index 527deac687..e3c3ace81c 100644 --- a/distributed/stealing.py +++ b/distributed/stealing.py @@ -426,6 +426,10 @@ def balance(self) -> None: log = [] start = time() + # Pre-calculate all occupancies once, they don't change during balancing + occupancies = {ws: ws.occupancy for ws in s.workers.values()} + combined_occupancy = partial(self._combined_occupancy, occupancies=occupancies) + i = 0 # Paused and closing workers must never become thieves potential_thieves = set(s.idle.values()) @@ -434,13 +438,11 @@ def balance(self) -> None: victim: WorkerState | None potential_victims: set[WorkerState] | list[WorkerState] = s.saturated if not potential_victims: - potential_victims = topk( - 10, s.workers.values(), key=self._combined_occupancy - ) + potential_victims = topk(10, s.workers.values(), key=combined_occupancy) potential_victims = [ ws for ws in potential_victims - if self._combined_occupancy(ws) > 0.2 + if combined_occupancy(ws) > 0.2 and self._combined_nprocessing(ws) > ws.nthreads and ws not in potential_thieves ] @@ -448,7 +450,7 @@ def balance(self) -> None: return if len(potential_victims) < 20: potential_victims = sorted( - potential_victims, key=self._combined_occupancy, reverse=True + potential_victims, key=combined_occupancy, reverse=True ) assert potential_victims assert potential_thieves @@ -472,11 +474,15 @@ def balance(self) -> None: stealable.discard(ts) continue i += 1 - if not (thief := self._get_thief(s, ts, potential_thieves)): + if not ( + thief := self._get_thief( + s, ts, potential_thieves, occupancies=occupancies + ) + ): continue - occ_thief = self._combined_occupancy(thief) - occ_victim = self._combined_occupancy(victim) + occ_thief = combined_occupancy(thief) + occ_victim = combined_occupancy(victim) comm_cost_thief = self.scheduler.get_comm_cost(ts, thief) comm_cost_victim = self.scheduler.get_comm_cost(ts, victim) compute = self.scheduler._get_prefix_duration(ts.prefix) @@ -501,7 +507,7 @@ def balance(self) -> None: self.metrics["request_count_total"][level] += 1 self.metrics["request_cost_total"][level] += cost - occ_thief = self._combined_occupancy(thief) + occ_thief = combined_occupancy(thief) nproc_thief = self._combined_nprocessing(thief) # FIXME: In the worst case, the victim may have 3x the amount of work @@ -515,7 +521,7 @@ def balance(self) -> None: # properly clean up, we would not need this stealable.discard(ts) self.scheduler.check_idle_saturated( - victim, occ=self._combined_occupancy(victim) + victim, occ=combined_occupancy(victim) ) if log: @@ -525,8 +531,10 @@ def balance(self) -> None: if s.digests: s.digests["steal-duration"].add(stop - start) - def _combined_occupancy(self, ws: WorkerState) -> float: - return ws.occupancy + self.in_flight_occupancy[ws] + def _combined_occupancy( + self, ws: WorkerState, *, occupancies: dict[WorkerState, float] + ) -> float: + return occupancies[ws] + self.in_flight_occupancy[ws] def _combined_nprocessing(self, ws: WorkerState) -> int: return len(ws.processing) + self.in_flight_tasks[ws] @@ -552,7 +560,9 @@ def story(self, *keys_or_ts: str | TaskState) -> list: out.append(t) return out - def stealing_objective(self, ts: TaskState, ws: WorkerState) -> tuple[float, ...]: + def stealing_objective( + self, ts: TaskState, ws: WorkerState, *, occupancies: dict[WorkerState, float] + ) -> tuple[float, ...]: """Objective function to determine which worker should get the task Minimize expected start time. If a tie then break with data storage. @@ -567,7 +577,8 @@ def stealing_objective(self, ts: TaskState, ws: WorkerState) -> tuple[float, ... Scheduler.worker_objective """ occupancy = self._combined_occupancy( - ws + ws, + occupancies=occupancies, ) / ws.nthreads + self.scheduler.get_comm_cost(ts, ws) if ts.actor: return (len(ws.actors), occupancy, ws.nbytes) @@ -579,6 +590,8 @@ def _get_thief( scheduler: SchedulerState, ts: TaskState, potential_thieves: set[WorkerState], + *, + occupancies: dict[WorkerState, float], ) -> WorkerState | None: valid_workers = scheduler.valid_workers(ts) if valid_workers is not None: @@ -587,7 +600,10 @@ def _get_thief( potential_thieves = valid_thieves elif not ts.loose_restrictions: return None - return min(potential_thieves, key=partial(self.stealing_objective, ts)) + return min( + potential_thieves, + key=partial(self.stealing_objective, ts, occupancies=occupancies), + ) fast_tasks = { diff --git a/distributed/tests/test_steal.py b/distributed/tests/test_steal.py index 57bff3bd79..72aa67bfbe 100644 --- a/distributed/tests/test_steal.py +++ b/distributed/tests/test_steal.py @@ -1948,7 +1948,7 @@ async def test_trivial_workload_should_not_cause_work_stealing(c, s, *workers): client=True, config={"distributed.scheduler.worker-saturation": "inf"}, ) -async def test_stealing_ogjective_accounts_for_in_flight(c, s, a): +async def test_stealing_objective_accounts_for_in_flight(c, s, a): """Regression test that work-stealing's objective correctly accounts for in-flight data requests""" in_event = Event() block_event = Event() @@ -1973,32 +1973,34 @@ def block(i: int, in_event: Event, block_event: Event) -> int: wsB = s.workers[b.address] ts = next(iter(wsA.processing)) + occupancies = {ws: ws.occupancy for ws in s.workers.values()} # No in-flight requests, so both match - assert extension.stealing_objective(ts, wsA) == s.worker_objective( - ts, wsA - ) - assert extension.stealing_objective(ts, wsB) == s.worker_objective( - ts, wsB - ) + assert extension.stealing_objective( + ts, wsA, occupancies=occupancies + ) == s.worker_objective(ts, wsA) + assert extension.stealing_objective( + ts, wsB, occupancies=occupancies + ) == s.worker_objective(ts, wsB) extension.balance() assert extension.in_flight # We move tasks from a to b - assert extension.stealing_objective(ts, wsA) < s.worker_objective( - ts, wsA - ) - assert extension.stealing_objective(ts, wsB) > s.worker_objective( - ts, wsB - ) + assert extension.stealing_objective( + ts, wsA, occupancies=occupancies + ) < s.worker_objective(ts, wsA) + assert extension.stealing_objective( + ts, wsB, occupancies=occupancies + ) > s.worker_objective(ts, wsB) await async_poll_for(lambda: not extension.in_flight, timeout=5) + occupancies = {ws: ws.occupancy for ws in s.workers.values()} # No in-flight requests, so both match - assert extension.stealing_objective(ts, wsA) == s.worker_objective( - ts, wsA - ) - assert extension.stealing_objective(ts, wsB) == s.worker_objective( - ts, wsB - ) + assert extension.stealing_objective( + ts, wsA, occupancies=occupancies + ) == s.worker_objective(ts, wsA) + assert extension.stealing_objective( + ts, wsB, occupancies=occupancies + ) == s.worker_objective(ts, wsB) finally: await block_event.set() finally: @@ -2031,7 +2033,7 @@ def block(i: int, in_event: Event, block_event: Event) -> int: await in_event.wait() # This is the pre-condition for the observed problem: - # There are tasks that execute fox a long time but do not have an average + # There are tasks that execute for a long time but do not have an average s.task_prefixes["block"].add_exec_time(100) assert s.task_prefixes["block"].duration_average == -1