From 93a96d4cd0f668019f54591eabc152127057b45d Mon Sep 17 00:00:00 2001 From: Tom White Date: Wed, 17 Jan 2024 17:09:07 +0000 Subject: [PATCH] Add a test to check that quad means computes the same result with full optimization --- cubed/tests/test_core.py | 36 ++++++++++++++++++++++++++++++++++++ 1 file changed, 36 insertions(+) diff --git a/cubed/tests/test_core.py b/cubed/tests/test_core.py index e352886a..94071ad5 100644 --- a/cubed/tests/test_core.py +++ b/cubed/tests/test_core.py @@ -1,4 +1,5 @@ import platform +import random import dill import numpy as np @@ -11,6 +12,7 @@ import cubed.random from cubed.backend_array_api import namespace as nxp from cubed.core.ops import merge_chunks, partial_reduce, tree_reduce +from cubed.core.optimization import fuse_all_optimize_dag from cubed.tests.utils import ( ALL_EXECUTORS, MAIN_EXECUTORS, @@ -533,3 +535,37 @@ def test_plan_quad_means(tmp_path, t_length): assert m.plan.num_tasks() > 0 m.visualize(filename=tmp_path / "quad_means") + + +def quad_means(tmp_path, t_length): + # based on sizes from https://gist.github.com/TomNicholas/c6a28f7c22c6981f75bce280d3e28283 + spec = cubed.Spec(tmp_path, allowed_mem="2GB", reserved_mem="100MB") + u = cubed.random.random((t_length, 1, 987, 1920), chunks=(10, 1, -1, -1), spec=spec) + v = cubed.random.random((t_length, 1, 987, 1920), chunks=(10, 1, -1, -1), spec=spec) + uv = u * v + m = xp.mean(uv, axis=0) + return m + + +def test_quad_means(tmp_path, t_length=50): + # run twice, with and without optimization + # set the random seed to ensure deterministic results + random.seed(42) + m0 = quad_means(tmp_path, t_length) + + random.seed(42) + m1 = quad_means(tmp_path, t_length) + + m1.visualize( + filename=tmp_path / "quad_means", optimize_function=fuse_all_optimize_dag + ) + + cubed.to_zarr(m0, store=tmp_path / "result0") + cubed.to_zarr( + m1, store=tmp_path / "result1", optimize_function=fuse_all_optimize_dag + ) + + res0 = zarr.open_array(tmp_path / "result0") + res1 = zarr.open_array(tmp_path / "result1") + + assert_array_equal(res0[:], res1[:])