Add example for xarray groupby reduction causing memory pressure (#1528)
Co-authored-by: Florian Jetter <[email protected]>
Co-authored-by: Hendrik Makait <[email protected]>
3 people authored Aug 20, 2024
1 parent 9de9f3b commit d216c83
Showing 6 changed files with 84 additions and 5 deletions.
1 change: 1 addition & 0 deletions AB_environments/AB_baseline.conda.yaml
@@ -28,6 +28,7 @@ dependencies:
- ipycytoscape ==1.3.3
- click ==8.1.7
- xarray ==2024.07.0
- flox ==0.9.9
- zarr ==2.18.2
- cftime ==1.6.4
- msgpack-python
1 change: 1 addition & 0 deletions AB_environments/AB_sample.conda.yaml
@@ -34,6 +34,7 @@ dependencies:
- ipycytoscape ==1.3.3
- click ==8.1.7
- xarray ==2024.07.0
- flox ==0.9.9
- zarr ==2.18.2
- cftime ==1.6.4
- msgpack-python
1 change: 1 addition & 0 deletions ci/environment.yml
@@ -30,6 +30,7 @@ dependencies:
- ipycytoscape ==1.3.3
- click ==8.1.7
- xarray ==2024.07.0
- flox ==0.9.9
- zarr ==2.18.2
- cftime ==1.6.4
- msgpack-python
6 changes: 6 additions & 0 deletions cluster_kwargs.yaml
@@ -33,6 +33,12 @@ spill_cluster:
worker_disk_size: 64
worker_vm_types: [m6i.large] # 2CPU, 8GiB

# For tests/benchmarks/test_xarray.py
group_reduction_cluster:
n_workers: 20
worker_vm_types: [m6i.xlarge] # 4CPU, 16GiB
region: "us-east-1" # Same region as dataset

# For tests/workflows/test_embarrassingly_parallel.py
embarrassingly_parallel:
n_workers: 100
11 changes: 6 additions & 5 deletions tests/benchmarks/test_array.py
@@ -37,10 +37,10 @@ def test_anom_mean(small_client, new_array):
dims=["time", "x"],
coords={"day": ("time", np.arange(data.shape[0]) % ngroups)},
)

clim = arr.groupby("day").mean(dim="time")
anom = arr.groupby("day") - clim
anom_mean = anom.mean(dim="time")
with xarray.set_options(use_flox=False):
clim = arr.groupby("day").mean(dim="time")
anom = arr.groupby("day") - clim
anom_mean = anom.mean(dim="time")

wait(anom_mean, small_client, 10 * 60)

@@ -136,7 +136,8 @@ def test_climatic_mean(small_client, new_array):
coords={"init_date": np.arange(data.shape[1]) % 10},
)
# arr_clim = array.groupby("init_date.month").mean(dim="init_date")
arr_clim = array.groupby("init_date").mean(dim="init_date")
with xarray.set_options(use_flox=False):
arr_clim = array.groupby("init_date").mean(dim="init_date")

wait(arr_clim, small_client, 15 * 60)

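The hunks above wrap the existing groupby reductions in xarray.set_options(use_flox=False), which forces xarray's built-in groupby path instead of dispatching the reduction to flox (now that flox is pinned in the environments, it would otherwise be picked up automatically). A minimal standalone sketch of the toggle, not part of the diff, with illustrative array sizes:

import numpy as np
import xarray

# Small in-memory analogue of test_anom_mean: group a time axis into "days"
# and compute an anomaly against the per-day climatology.
arr = xarray.DataArray(
    np.random.random((100, 10)),
    dims=["time", "x"],
    coords={"day": ("time", np.arange(100) % 5)},
)
with xarray.set_options(use_flox=False):
    clim = arr.groupby("day").mean(dim="time")  # native (non-flox) groupby reduction
    anom = arr.groupby("day") - clim            # broadcast climatology back over time
    anom_mean = anom.mean(dim="time")
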
69 changes: 69 additions & 0 deletions tests/benchmarks/test_xarray.py
@@ -0,0 +1,69 @@
import uuid

import fsspec
import pytest
from coiled import Cluster
from distributed import Client

from tests.conftest import dump_cluster_kwargs
from tests.utils_test import wait

xr = pytest.importorskip("xarray")
pytest.importorskip("flox")


@pytest.fixture(scope="module")
def group_reduction_cluster(dask_env_variables, cluster_kwargs, github_cluster_tags):
kwargs = dict(
name=f"xarray-group-reduction-{uuid.uuid4().hex[:8]}",
environ=dask_env_variables,
tags=github_cluster_tags,
**cluster_kwargs["group_reduction_cluster"],
)
dump_cluster_kwargs(kwargs, "group_reduction_cluster")
with Cluster(**kwargs) as cluster:
yield cluster


@pytest.fixture
def group_reduction_client(
group_reduction_cluster, cluster_kwargs, upload_cluster_dump, benchmark_all
):
n_workers = cluster_kwargs["group_reduction_cluster"]["n_workers"]
with Client(group_reduction_cluster) as client:
group_reduction_cluster.scale(n_workers)
client.wait_for_workers(n_workers, timeout=600)
client.restart()
with upload_cluster_dump(client), benchmark_all(client):
yield client


@pytest.mark.parametrize(
"func",
[
pytest.param(
lambda x: x.groupby("time.month").mean(method="cohorts"), id="cohorts"
),
pytest.param(
lambda x: x.groupby("time.month").mean(method="map-reduce"), id="map-reduce"
),
pytest.param(
lambda x: x.chunk(time=xr.groupers.TimeResampler("ME"))
.groupby("time.month")
.mean(method="cohorts"),
id="chunked-cohorts",
),
],
)
def test_xarray_groupby_reduction(group_reduction_client, func):
ds = xr.open_zarr(
fsspec.get_mapper(
"s3://noaa-nwm-retrospective-2-1-zarr-pds/rtout.zarr", anon=True
),
consolidated=True,
)
# Slice the dataset down to keep the runtime in check
subset = ds.zwattablrt.sel(time=slice("2001", "2002"))
subset = subset.isel(x=slice(0, 350 * 8), y=slice(0, 350 * 8))
result = func(subset)
wait(result, group_reduction_client, 10 * 60)

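For context on the parametrized method= argument above: xarray forwards it to flox, where "map-reduce" reduces every chunk against all groups and then combines the intermediates, while "cohorts" first partitions chunks into cohorts that share group labels and reduces each cohort separately, which can lower memory pressure for periodic groupings such as months. A minimal, self-contained sketch on synthetic data (not the NOAA Zarr dataset used in the test; assumes dask and flox are installed):

import numpy as np
import pandas as pd
import xarray as xr

# Two years of daily data, chunked along time so the groupby runs on dask arrays.
time = pd.date_range("2001-01-01", "2002-12-31", freq="D")
da = xr.DataArray(
    np.random.random((time.size, 100)),
    dims=["time", "x"],
    coords={"time": time},
).chunk({"time": 90})

# With flox installed, xarray dispatches the reduction to flox and forwards method=.
monthly_mr = da.groupby("time.month").mean(method="map-reduce")
monthly_co = da.groupby("time.month").mean(method="cohorts")

# Both strategies produce the same 12-month climatology; they differ in how the
# dask graph is built and how much intermediate state is held in memory.
monthly_mr.compute()
monthly_co.compute()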