From 4b1805359e100a6842cc045b143a974594b93fdf Mon Sep 17 00:00:00 2001 From: Tom White Date: Mon, 18 Nov 2024 09:46:15 +0000 Subject: [PATCH] Change default `max_total_num_input_blocks` to 10 (#615) To allow operations to be fused as long as the total number of input blocks does not exceed this number. Previously, the default was None, which meant operations were fused only if they had the same number of tasks. --- cubed/core/optimization.py | 15 ++++++----- cubed/tests/test_optimization.py | 44 ++++++++++++++++++++++++-------- docs/user-guide/optimization.md | 4 +-- 3 files changed, 44 insertions(+), 19 deletions(-) diff --git a/cubed/core/optimization.py b/cubed/core/optimization.py index 346d8e0c..5ec5064c 100644 --- a/cubed/core/optimization.py +++ b/cubed/core/optimization.py @@ -11,6 +11,9 @@ logger = logging.getLogger(__name__) +DEFAULT_MAX_TOTAL_SOURCE_ARRAYS = 4 +DEFAULT_MAX_TOTAL_NUM_INPUT_BLOCKS = 10 + def simple_optimize_dag(dag, array_names=None): """Apply map blocks fusion.""" @@ -154,8 +157,8 @@ def can_fuse_predecessors( name, *, array_names=None, - max_total_source_arrays=4, - max_total_num_input_blocks=None, + max_total_source_arrays=DEFAULT_MAX_TOTAL_SOURCE_ARRAYS, + max_total_num_input_blocks=DEFAULT_MAX_TOTAL_NUM_INPUT_BLOCKS, always_fuse=None, never_fuse=None, ): @@ -242,8 +245,8 @@ def fuse_predecessors( name, *, array_names=None, - max_total_source_arrays=4, - max_total_num_input_blocks=None, + max_total_source_arrays=DEFAULT_MAX_TOTAL_SOURCE_ARRAYS, + max_total_num_input_blocks=DEFAULT_MAX_TOTAL_NUM_INPUT_BLOCKS, always_fuse=None, never_fuse=None, ): @@ -297,8 +300,8 @@ def multiple_inputs_optimize_dag( dag, *, array_names=None, - max_total_source_arrays=4, - max_total_num_input_blocks=None, + max_total_source_arrays=DEFAULT_MAX_TOTAL_SOURCE_ARRAYS, + max_total_num_input_blocks=DEFAULT_MAX_TOTAL_NUM_INPUT_BLOCKS, always_fuse=None, never_fuse=None, ): diff --git a/cubed/tests/test_optimization.py b/cubed/tests/test_optimization.py index 6d635f0c..ebb94022 100644 --- a/cubed/tests/test_optimization.py +++ b/cubed/tests/test_optimization.py @@ -11,6 +11,8 @@ from cubed.backend_array_api import namespace as nxp from cubed.core.ops import elemwise, merge_chunks, partial_reduce from cubed.core.optimization import ( + DEFAULT_MAX_TOTAL_NUM_INPUT_BLOCKS, + DEFAULT_MAX_TOTAL_SOURCE_ARRAYS, fuse_all_optimize_dag, fuse_only_optimize_dag, fuse_predecessors, @@ -223,7 +225,11 @@ def fuse_one_level(arr, *, always_fuse=None): ) -def fuse_multiple_levels(*, max_total_source_arrays=4, max_total_num_input_blocks=None): +def fuse_multiple_levels( + *, + max_total_source_arrays=DEFAULT_MAX_TOTAL_SOURCE_ARRAYS, + max_total_num_input_blocks=DEFAULT_MAX_TOTAL_NUM_INPUT_BLOCKS, +): # use multiple_inputs_optimize_dag to test multiple levels of fusion return partial( multiple_inputs_optimize_dag, @@ -899,8 +905,7 @@ def test_fuse_merge_chunks_unary(spec): b = xp.negative(a) c = merge_chunks(b, chunks=(3, 2)) - # specify max_total_num_input_blocks to force c to fuse - opt_fn = fuse_multiple_levels(max_total_num_input_blocks=3) + opt_fn = fuse_multiple_levels() c.visualize(optimize_function=opt_fn) @@ -921,6 +926,16 @@ def test_fuse_merge_chunks_unary(spec): result = c.compute(optimize_function=opt_fn) assert_array_equal(result, -np.ones((3, 2))) + # now set max_total_num_input_blocks=None which means + # "only fuse if ops have same number of tasks", which they don't here + opt_fn = fuse_multiple_levels(max_total_num_input_blocks=None) + optimized_dag = c.plan.optimize(optimize_function=opt_fn).dag + + # merge_chunks uses a hidden op and array for block ids - ignore when comparing structure + assert not structurally_equivalent( + optimized_dag, expected_fused_dag, remove_hidden=True + ) + # merge chunks with different number of tasks (c has more tasks than d) # @@ -936,8 +951,7 @@ def test_fuse_merge_chunks_binary(spec): c = xp.add(a, b) d = merge_chunks(c, chunks=(3, 2)) - # specify max_total_num_input_blocks to force d to fuse - opt_fn = fuse_multiple_levels(max_total_num_input_blocks=6) + opt_fn = fuse_multiple_levels() d.visualize(optimize_function=opt_fn) @@ -963,6 +977,16 @@ def test_fuse_merge_chunks_binary(spec): result = d.compute(optimize_function=opt_fn) assert_array_equal(result, 2 * np.ones((3, 2))) + # now set max_total_num_input_blocks=None which means + # "only fuse if ops have same number of tasks", which they don't here + opt_fn = fuse_multiple_levels(max_total_num_input_blocks=None) + optimized_dag = d.plan.optimize(optimize_function=opt_fn).dag + + # merge_chunks uses a hidden op and array for block ids - ignore when comparing structure + assert not structurally_equivalent( + optimized_dag, expected_fused_dag, remove_hidden=True + ) + # like test_fuse_merge_chunks_unary, except uses partial_reduce def test_fuse_partial_reduce_unary(spec): @@ -970,8 +994,7 @@ def test_fuse_partial_reduce_unary(spec): b = xp.negative(a) c = partial_reduce(b, np.sum, split_every={0: 3}) - # specify max_total_num_input_blocks to force c to fuse - opt_fn = fuse_multiple_levels(max_total_num_input_blocks=3) + opt_fn = fuse_multiple_levels() c.visualize(optimize_function=opt_fn) @@ -996,8 +1019,7 @@ def test_fuse_partial_reduce_binary(spec): c = xp.add(a, b) d = partial_reduce(c, np.sum, split_every={0: 3}) - # specify max_total_num_input_blocks to force d to fuse - opt_fn = fuse_multiple_levels(max_total_num_input_blocks=6) + opt_fn = fuse_multiple_levels() d.visualize(optimize_function=opt_fn) @@ -1176,7 +1198,7 @@ def test_optimize_stack(spec): c = xp.stack((a, b), axis=0) d = c + 1 # try to fuse all ops into one (d will fuse with c, but c won't fuse with a and b) - d.compute(optimize_function=fuse_multiple_levels(max_total_num_input_blocks=10)) + d.compute(optimize_function=fuse_multiple_levels()) def test_optimize_concat(spec): @@ -1186,4 +1208,4 @@ def test_optimize_concat(spec): c = xp.concat((a, b), axis=0) d = c + 1 # try to fuse all ops into one (d will fuse with c, but c won't fuse with a and b) - d.compute(optimize_function=fuse_multiple_levels(max_total_num_input_blocks=10)) + d.compute(optimize_function=fuse_multiple_levels()) diff --git a/docs/user-guide/optimization.md b/docs/user-guide/optimization.md index b8e3a458..ec556546 100644 --- a/docs/user-guide/optimization.md +++ b/docs/user-guide/optimization.md @@ -112,9 +112,9 @@ e.visualize(optimize_function=opt_fn) The `max_total_num_input_blocks` argument to `multiple_inputs_optimize_dag` specifies the maximum number of input blocks (chunks) that are allowed in the fused operation. -Again, this is to limit the number of reads that an individual task must perform. The default is `None`, which means that operations are fused only if they have the same number of tasks. If set to an integer, then this limitation is removed, and tasks with a different number of tasks will be fused - as long as the total number of input blocks does not exceed the maximum. This setting is useful for reductions, and can be set using `functools.partial`: +Again, this is to limit the number of reads that an individual task must perform. If set to `None`, operations are fused only if they have the same number of tasks. If set to an integer (the default is 10), then tasks with a different number of tasks will be fused - as long as the total number of input blocks does not exceed the maximum. This setting is useful for reductions, and can be changed using `functools.partial`: ```python -opt_fn = partial(multiple_inputs_optimize_dag, max_total_num_input_blocks=10) +opt_fn = partial(multiple_inputs_optimize_dag, max_total_num_input_blocks=20) e.visualize(optimize_function=opt_fn) ```