Skip to content

Commit

Permalink
Fix inconsistent hashing for Nanny-spawned workers (#8400)
Browse files Browse the repository at this point in the history
  • Loading branch information
cisaacstern authored Dec 18, 2023
1 parent 53e95ec commit 8c3eb6f
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 1 deletion.
7 changes: 7 additions & 0 deletions distributed/nanny.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,13 @@ def __init__( # type: ignore[no-untyped-def]
self.Worker = Worker if worker_class is None else worker_class

self.pre_spawn_env = _get_env_variables("distributed.nanny.pre-spawn-environ")
# To get consistent hashing on subprocesses, we need to set a consistent seed for
# the Python hash algorithm; xref https://github.com/dask/distributed/pull/8400
if self.pre_spawn_env.get("PYTHONHASHSEED") in (None, "0"):
# This number is arbitrary; it was chosen to commemorate
# https://github.com/dask/dask/issues/6640.
self.pre_spawn_env.update({"PYTHONHASHSEED": "6640"})

self.env = merge(
self.pre_spawn_env,
_get_env_variables("distributed.nanny.environ"),
Expand Down
13 changes: 13 additions & 0 deletions distributed/tests/test_dask_collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import dask.dataframe as dd

from distributed.client import wait
from distributed.nanny import Nanny
from distributed.utils_test import gen_cluster

dfs = [
Expand Down Expand Up @@ -124,6 +125,18 @@ async def test_bag_groupby_tasks_default(c, s, a, b):
assert not any("partd" in k[0] for k in b2.dask)


@gen_cluster(client=True, Worker=Nanny)
async def test_bag_groupby_key_hashing(c, s, a, b):
# https://github.com/dask/distributed/issues/4141
dsk = {("x", 0): (range, 5), ("x", 1): (range, 5), ("x", 2): (range, 5)}
grouped = db.Bag(dsk, "x", 3).groupby(lambda x: "even" if x % 2 == 0 else "odd")
remote = c.compute(grouped)
result = await remote
assert len(result) == 2
assert ("odd", [1, 3] * 3) in result
assert ("even", [0, 2, 4] * 3) in result


@pytest.mark.parametrize("wait", [wait, lambda x: None])
def test_dataframe_set_index_sync(wait, client):
df = dask.datasets.timeseries(
Expand Down
5 changes: 4 additions & 1 deletion distributed/tests/test_nanny.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,10 +340,13 @@ async def test_environment_variable_config(c, s, monkeypatch):
},
)
async def test_environment_variable_pre_post_spawn(c, s, n):
assert n.env == {"PRE-SPAWN": "1", "POST-SPAWN": "2"}
assert n.env == {"PRE-SPAWN": "1", "POST-SPAWN": "2", "PYTHONHASHSEED": "6640"}
results = await c.run(lambda: os.environ)
assert results[n.worker_address]["PRE-SPAWN"] == "1"
assert results[n.worker_address]["POST-SPAWN"] == "2"
# if unset in pre-spawn-environ config, PYTHONHASHSEED defaults to "6640" to ensure
# consistent hashing across workers; https://github.com/dask/distributed/issues/4141
assert results[n.worker_address]["PYTHONHASHSEED"] == "6640"

del os.environ["PRE-SPAWN"]
assert "POST-SPAWN" not in os.environ
Expand Down

0 comments on commit 8c3eb6f

Please sign in to comment.