diff --git a/AB_environments/AB_baseline.conda.yaml b/AB_environments/AB_baseline.conda.yaml
index 2aeb5ce2cc..cb0c031497 100644
--- a/AB_environments/AB_baseline.conda.yaml
+++ b/AB_environments/AB_baseline.conda.yaml
@@ -28,6 +28,7 @@ dependencies:
   - ipycytoscape ==1.3.3
   - click ==8.1.7
   - xarray ==2024.07.0
+  - flox ==0.9.9
   - zarr ==2.18.2
   - cftime ==1.6.4
   - msgpack-python
diff --git a/AB_environments/AB_sample.conda.yaml b/AB_environments/AB_sample.conda.yaml
index 77682e5c1c..b6bd5e3b3e 100644
--- a/AB_environments/AB_sample.conda.yaml
+++ b/AB_environments/AB_sample.conda.yaml
@@ -34,6 +34,7 @@ dependencies:
   - ipycytoscape ==1.3.3
   - click ==8.1.7
   - xarray ==2024.07.0
+  - flox ==0.9.9
   - zarr ==2.18.2
   - cftime ==1.6.4
   - msgpack-python
diff --git a/ci/environment.yml b/ci/environment.yml
index 8fafe24f32..7b2072ff46 100644
--- a/ci/environment.yml
+++ b/ci/environment.yml
@@ -30,6 +30,7 @@ dependencies:
   - ipycytoscape ==1.3.3
   - click ==8.1.7
   - xarray ==2024.07.0
+  - flox ==0.9.9
   - zarr ==2.18.2
   - cftime ==1.6.4
   - msgpack-python
diff --git a/cluster_kwargs.yaml b/cluster_kwargs.yaml
index c24a15c654..a66ac16e88 100644
--- a/cluster_kwargs.yaml
+++ b/cluster_kwargs.yaml
@@ -33,6 +33,12 @@ spill_cluster:
   worker_disk_size: 64
   worker_vm_types: [m6i.large] # 2CPU, 8GiB
 
+# For tests/benchmarks/test_xarray.py
+group_reduction_cluster:
+  n_workers: 20
+  worker_vm_types: [m6i.xlarge] # 4CPU, 16GiB
+  region: "us-east-1" # Same region as dataset
+
 # For tests/workflows/test_embarrassingly_parallel.py
 embarrassingly_parallel:
   n_workers: 100
diff --git a/tests/benchmarks/test_array.py b/tests/benchmarks/test_array.py
index 1b54b5d255..6b207aa27f 100644
--- a/tests/benchmarks/test_array.py
+++ b/tests/benchmarks/test_array.py
@@ -37,10 +37,10 @@ def test_anom_mean(small_client, new_array):
         dims=["time", "x"],
         coords={"day": ("time", np.arange(data.shape[0]) % ngroups)},
     )
-
-    clim = arr.groupby("day").mean(dim="time")
-    anom = arr.groupby("day") - clim
-    anom_mean = anom.mean(dim="time")
+    with xarray.set_options(use_flox=False):
+        clim = arr.groupby("day").mean(dim="time")
+        anom = arr.groupby("day") - clim
+        anom_mean = anom.mean(dim="time")
 
     wait(anom_mean, small_client, 10 * 60)
 
@@ -136,7 +136,8 @@ def test_climatic_mean(small_client, new_array):
         coords={"init_date": np.arange(data.shape[1]) % 10},
     )
     # arr_clim = array.groupby("init_date.month").mean(dim="init_date")
-    arr_clim = array.groupby("init_date").mean(dim="init_date")
+    with xarray.set_options(use_flox=False):
+        arr_clim = array.groupby("init_date").mean(dim="init_date")
 
     wait(arr_clim, small_client, 15 * 60)
 
diff --git a/tests/benchmarks/test_xarray.py b/tests/benchmarks/test_xarray.py
new file mode 100644
index 0000000000..d7ed6f705a
--- /dev/null
+++ b/tests/benchmarks/test_xarray.py
@@ -0,0 +1,69 @@
+import uuid
+
+import fsspec
+import pytest
+from coiled import Cluster
+from distributed import Client
+
+from tests.conftest import dump_cluster_kwargs
+from tests.utils_test import wait
+
+xr = pytest.importorskip("xarray")
+pytest.importorskip("flox")
+
+
+@pytest.fixture(scope="module")
+def group_reduction_cluster(dask_env_variables, cluster_kwargs, github_cluster_tags):
+    kwargs = dict(
+        name=f"xarray-group-reduction-{uuid.uuid4().hex[:8]}",
+        environ=dask_env_variables,
+        tags=github_cluster_tags,
+        **cluster_kwargs["group_reduction_cluster"],
+    )
+    dump_cluster_kwargs(kwargs, "group_reduction_cluster")
+    with Cluster(**kwargs) as cluster:
+        yield cluster
+
+
+@pytest.fixture
+def group_reduction_client(
+    group_reduction_cluster, cluster_kwargs, upload_cluster_dump, benchmark_all
+):
+    n_workers = cluster_kwargs["group_reduction_cluster"]["n_workers"]
+    with Client(group_reduction_cluster) as client:
+        group_reduction_cluster.scale(n_workers)
+        client.wait_for_workers(n_workers, timeout=600)
+        client.restart()
+        with upload_cluster_dump(client), benchmark_all(client):
+            yield client
+
+
+@pytest.mark.parametrize(
+    "func",
+    [
+        pytest.param(
+            lambda x: x.groupby("time.month").mean(method="cohorts"), id="cohorts"
+        ),
+        pytest.param(
+            lambda x: x.groupby("time.month").mean(method="map-reduce"), id="map-reduce"
+        ),
+        pytest.param(
+            lambda x: x.chunk(time=xr.groupers.TimeResampler("ME"))
+            .groupby("time.month")
+            .mean(method="cohorts"),
+            id="chunked-cohorts",
+        ),
+    ],
+)
+def test_xarray_groupby_reduction(group_reduction_client, func):
+    ds = xr.open_zarr(
+        fsspec.get_mapper(
+            "s3://noaa-nwm-retrospective-2-1-zarr-pds/rtout.zarr", anon=True
+        ),
+        consolidated=True,
+    )
+    # slice dataset properly to keep runtime in check
+    subset = ds.zwattablrt.sel(time=slice("2001", "2002"))
+    subset = subset.isel(x=slice(0, 350 * 8), y=slice(0, 350 * 8))
+    result = func(subset)
+    wait(result, group_reduction_client, 10 * 60)