Skip to content

Commit

Permalink
Cache occupancy in WorkStealing.balance() (#9005)
Browse files Browse the repository at this point in the history
  • Loading branch information
hendrikmakait authored Feb 5, 2025
1 parent 348082f commit e5711c6
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 35 deletions.
46 changes: 31 additions & 15 deletions distributed/stealing.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,6 +426,10 @@ def balance(self) -> None:
log = []
start = time()

+# Pre-calculate all occupancies once, they don't change during balancing
+occupancies = {ws: ws.occupancy for ws in s.workers.values()}
+combined_occupancy = partial(self._combined_occupancy, occupancies=occupancies)
+
i = 0
# Paused and closing workers must never become thieves
potential_thieves = set(s.idle.values())
Expand All @@ -434,21 +438,19 @@ def balance(self) -> None:
victim: WorkerState | None
potential_victims: set[WorkerState] | list[WorkerState] = s.saturated
if not potential_victims:
-potential_victims = topk(
-10, s.workers.values(), key=self._combined_occupancy
-)
+potential_victims = topk(10, s.workers.values(), key=combined_occupancy)
potential_victims = [
ws
for ws in potential_victims
-if self._combined_occupancy(ws) > 0.2
+if combined_occupancy(ws) > 0.2
and self._combined_nprocessing(ws) > ws.nthreads
and ws not in potential_thieves
]
if not potential_victims:
return
if len(potential_victims) < 20:
potential_victims = sorted(
-potential_victims, key=self._combined_occupancy, reverse=True
+potential_victims, key=combined_occupancy, reverse=True
)
assert potential_victims
assert potential_thieves
Expand All @@ -472,11 +474,15 @@ def balance(self) -> None:
stealable.discard(ts)
continue
i += 1
-if not (thief := self._get_thief(s, ts, potential_thieves)):
+if not (
+thief := self._get_thief(
+s, ts, potential_thieves, occupancies=occupancies
+)
+):
continue

-occ_thief = self._combined_occupancy(thief)
-occ_victim = self._combined_occupancy(victim)
+occ_thief = combined_occupancy(thief)
+occ_victim = combined_occupancy(victim)
comm_cost_thief = self.scheduler.get_comm_cost(ts, thief)
comm_cost_victim = self.scheduler.get_comm_cost(ts, victim)
compute = self.scheduler._get_prefix_duration(ts.prefix)
Expand All @@ -501,7 +507,7 @@ def balance(self) -> None:
self.metrics["request_count_total"][level] += 1
self.metrics["request_cost_total"][level] += cost

-occ_thief = self._combined_occupancy(thief)
+occ_thief = combined_occupancy(thief)
nproc_thief = self._combined_nprocessing(thief)

# FIXME: In the worst case, the victim may have 3x the amount of work
Expand All @@ -515,7 +521,7 @@ def balance(self) -> None:
# properly clean up, we would not need this
stealable.discard(ts)
self.scheduler.check_idle_saturated(
-victim, occ=self._combined_occupancy(victim)
+victim, occ=combined_occupancy(victim)
)

if log:
Expand All @@ -525,8 +531,10 @@ def balance(self) -> None:
if s.digests:
s.digests["steal-duration"].add(stop - start)

-def _combined_occupancy(self, ws: WorkerState) -> float:
-return ws.occupancy + self.in_flight_occupancy[ws]
+def _combined_occupancy(
+self, ws: WorkerState, *, occupancies: dict[WorkerState, float]
+) -> float:
+return occupancies[ws] + self.in_flight_occupancy[ws]

def _combined_nprocessing(self, ws: WorkerState) -> int:
return len(ws.processing) + self.in_flight_tasks[ws]
Expand All @@ -552,7 +560,9 @@ def story(self, *keys_or_ts: str | TaskState) -> list:
out.append(t)
return out

-def stealing_objective(self, ts: TaskState, ws: WorkerState) -> tuple[float, ...]:
+def stealing_objective(
+self, ts: TaskState, ws: WorkerState, *, occupancies: dict[WorkerState, float]
+) -> tuple[float, ...]:
"""Objective function to determine which worker should get the task
Minimize expected start time. If a tie then break with data storage.
Expand All @@ -567,7 +577,8 @@ def stealing_objective(self, ts: TaskState, ws: WorkerState) -> tuple[float, ...
Scheduler.worker_objective
"""
occupancy = self._combined_occupancy(
-ws
+ws,
+occupancies=occupancies,
) / ws.nthreads + self.scheduler.get_comm_cost(ts, ws)
if ts.actor:
return (len(ws.actors), occupancy, ws.nbytes)
Expand All @@ -579,6 +590,8 @@ def _get_thief(
scheduler: SchedulerState,
ts: TaskState,
potential_thieves: set[WorkerState],
+*,
+occupancies: dict[WorkerState, float],
) -> WorkerState | None:
valid_workers = scheduler.valid_workers(ts)
if valid_workers is not None:
Expand All @@ -587,7 +600,10 @@ def _get_thief(
potential_thieves = valid_thieves
elif not ts.loose_restrictions:
return None
-return min(potential_thieves, key=partial(self.stealing_objective, ts))
+return min(
+potential_thieves,
+key=partial(self.stealing_objective, ts, occupancies=occupancies),
+)


fast_tasks = {
Expand Down
42 changes: 22 additions & 20 deletions distributed/tests/test_steal.py
Original file line number Diff line number Diff line change
Expand Up @@ -1948,7 +1948,7 @@ async def test_trivial_workload_should_not_cause_work_stealing(c, s, *workers):
client=True,
config={"distributed.scheduler.worker-saturation": "inf"},
)
-async def test_stealing_ogjective_accounts_for_in_flight(c, s, a):
+async def test_stealing_objective_accounts_for_in_flight(c, s, a):
"""Regression test that work-stealing's objective correctly accounts for in-flight data requests"""
in_event = Event()
block_event = Event()
Expand All @@ -1973,32 +1973,34 @@ def block(i: int, in_event: Event, block_event: Event) -> int:
wsB = s.workers[b.address]
ts = next(iter(wsA.processing))

+occupancies = {ws: ws.occupancy for ws in s.workers.values()}
# No in-flight requests, so both match
-assert extension.stealing_objective(ts, wsA) == s.worker_objective(
-ts, wsA
-)
-assert extension.stealing_objective(ts, wsB) == s.worker_objective(
-ts, wsB
-)
+assert extension.stealing_objective(
+ts, wsA, occupancies=occupancies
+) == s.worker_objective(ts, wsA)
+assert extension.stealing_objective(
+ts, wsB, occupancies=occupancies
+) == s.worker_objective(ts, wsB)

extension.balance()
assert extension.in_flight
# We move tasks from a to b
-assert extension.stealing_objective(ts, wsA) < s.worker_objective(
-ts, wsA
-)
-assert extension.stealing_objective(ts, wsB) > s.worker_objective(
-ts, wsB
-)
+assert extension.stealing_objective(
+ts, wsA, occupancies=occupancies
+) < s.worker_objective(ts, wsA)
+assert extension.stealing_objective(
+ts, wsB, occupancies=occupancies
+) > s.worker_objective(ts, wsB)

await async_poll_for(lambda: not extension.in_flight, timeout=5)
+occupancies = {ws: ws.occupancy for ws in s.workers.values()}
# No in-flight requests, so both match
-assert extension.stealing_objective(ts, wsA) == s.worker_objective(
-ts, wsA
-)
-assert extension.stealing_objective(ts, wsB) == s.worker_objective(
-ts, wsB
-)
+assert extension.stealing_objective(
+ts, wsA, occupancies=occupancies
+) == s.worker_objective(ts, wsA)
+assert extension.stealing_objective(
+ts, wsB, occupancies=occupancies
+) == s.worker_objective(ts, wsB)
finally:
await block_event.set()
finally:
Expand Down Expand Up @@ -2031,7 +2033,7 @@ def block(i: int, in_event: Event, block_event: Event) -> int:
await in_event.wait()

# This is the pre-condition for the observed problem:
-# There are tasks that execute fox a long time but do not have an average
+# There are tasks that execute for a long time but do not have an average
s.task_prefixes["block"].add_exec_time(100)
assert s.task_prefixes["block"].duration_average == -1

Expand Down

0 comments on commit e5711c6

Please sign in to comment.