diff --git a/cubed/core/optimization.py b/cubed/core/optimization.py
index 346d8e0c..5ec5064c 100644
--- a/cubed/core/optimization.py
+++ b/cubed/core/optimization.py
@@ -11,6 +11,9 @@
 
 logger = logging.getLogger(__name__)
 
+DEFAULT_MAX_TOTAL_SOURCE_ARRAYS = 4
+DEFAULT_MAX_TOTAL_NUM_INPUT_BLOCKS = 10
+
 
 def simple_optimize_dag(dag, array_names=None):
     """Apply map blocks fusion."""
@@ -154,8 +157,8 @@ def can_fuse_predecessors(
     name,
     *,
     array_names=None,
-    max_total_source_arrays=4,
-    max_total_num_input_blocks=None,
+    max_total_source_arrays=DEFAULT_MAX_TOTAL_SOURCE_ARRAYS,
+    max_total_num_input_blocks=DEFAULT_MAX_TOTAL_NUM_INPUT_BLOCKS,
     always_fuse=None,
     never_fuse=None,
 ):
@@ -242,8 +245,8 @@ def fuse_predecessors(
     name,
     *,
     array_names=None,
-    max_total_source_arrays=4,
-    max_total_num_input_blocks=None,
+    max_total_source_arrays=DEFAULT_MAX_TOTAL_SOURCE_ARRAYS,
+    max_total_num_input_blocks=DEFAULT_MAX_TOTAL_NUM_INPUT_BLOCKS,
     always_fuse=None,
     never_fuse=None,
 ):
@@ -297,8 +300,8 @@ def multiple_inputs_optimize_dag(
     dag,
     *,
     array_names=None,
-    max_total_source_arrays=4,
-    max_total_num_input_blocks=None,
+    max_total_source_arrays=DEFAULT_MAX_TOTAL_SOURCE_ARRAYS,
+    max_total_num_input_blocks=DEFAULT_MAX_TOTAL_NUM_INPUT_BLOCKS,
     always_fuse=None,
     never_fuse=None,
 ):
diff --git a/cubed/tests/test_optimization.py b/cubed/tests/test_optimization.py
index 6d635f0c..ebb94022 100644
--- a/cubed/tests/test_optimization.py
+++ b/cubed/tests/test_optimization.py
@@ -11,6 +11,8 @@
 from cubed.backend_array_api import namespace as nxp
 from cubed.core.ops import elemwise, merge_chunks, partial_reduce
 from cubed.core.optimization import (
+    DEFAULT_MAX_TOTAL_NUM_INPUT_BLOCKS,
+    DEFAULT_MAX_TOTAL_SOURCE_ARRAYS,
     fuse_all_optimize_dag,
     fuse_only_optimize_dag,
     fuse_predecessors,
@@ -223,7 +225,11 @@ def fuse_one_level(arr, *, always_fuse=None):
     )
 
 
-def fuse_multiple_levels(*, max_total_source_arrays=4, max_total_num_input_blocks=None):
+def fuse_multiple_levels(
+    *,
+    max_total_source_arrays=DEFAULT_MAX_TOTAL_SOURCE_ARRAYS,
+    max_total_num_input_blocks=DEFAULT_MAX_TOTAL_NUM_INPUT_BLOCKS,
+):
     # use multiple_inputs_optimize_dag to test multiple levels of fusion
     return partial(
         multiple_inputs_optimize_dag,
@@ -899,8 +905,7 @@ def test_fuse_merge_chunks_unary(spec):
     b = xp.negative(a)
     c = merge_chunks(b, chunks=(3, 2))
 
-    # specify max_total_num_input_blocks to force c to fuse
-    opt_fn = fuse_multiple_levels(max_total_num_input_blocks=3)
+    opt_fn = fuse_multiple_levels()
 
     c.visualize(optimize_function=opt_fn)
 
@@ -921,6 +926,16 @@ def test_fuse_merge_chunks_unary(spec):
     result = c.compute(optimize_function=opt_fn)
     assert_array_equal(result, -np.ones((3, 2)))
 
+    # now set max_total_num_input_blocks=None which means
+    # "only fuse if ops have same number of tasks", which they don't here
+    opt_fn = fuse_multiple_levels(max_total_num_input_blocks=None)
+    optimized_dag = c.plan.optimize(optimize_function=opt_fn).dag
+
+    # merge_chunks uses a hidden op and array for block ids - ignore when comparing structure
+    assert not structurally_equivalent(
+        optimized_dag, expected_fused_dag, remove_hidden=True
+    )
+
 
 # merge chunks with different number of tasks (c has more tasks than d)
 #
@@ -936,8 +951,7 @@ def test_fuse_merge_chunks_binary(spec):
     c = xp.add(a, b)
     d = merge_chunks(c, chunks=(3, 2))
 
-    # specify max_total_num_input_blocks to force d to fuse
-    opt_fn = fuse_multiple_levels(max_total_num_input_blocks=6)
+    opt_fn = fuse_multiple_levels()
 
     d.visualize(optimize_function=opt_fn)
 
@@ -963,6 +977,16 @@ def test_fuse_merge_chunks_binary(spec):
     result = d.compute(optimize_function=opt_fn)
     assert_array_equal(result, 2 * np.ones((3, 2)))
 
+    # now set max_total_num_input_blocks=None which means
+    # "only fuse if ops have same number of tasks", which they don't here
+    opt_fn = fuse_multiple_levels(max_total_num_input_blocks=None)
+    optimized_dag = d.plan.optimize(optimize_function=opt_fn).dag
+
+    # merge_chunks uses a hidden op and array for block ids - ignore when comparing structure
+    assert not structurally_equivalent(
+        optimized_dag, expected_fused_dag, remove_hidden=True
+    )
+
 
 # like test_fuse_merge_chunks_unary, except uses partial_reduce
 def test_fuse_partial_reduce_unary(spec):
@@ -970,8 +994,7 @@ def test_fuse_partial_reduce_unary(spec):
     b = xp.negative(a)
     c = partial_reduce(b, np.sum, split_every={0: 3})
 
-    # specify max_total_num_input_blocks to force c to fuse
-    opt_fn = fuse_multiple_levels(max_total_num_input_blocks=3)
+    opt_fn = fuse_multiple_levels()
 
     c.visualize(optimize_function=opt_fn)
 
@@ -996,8 +1019,7 @@ def test_fuse_partial_reduce_binary(spec):
     c = xp.add(a, b)
     d = partial_reduce(c, np.sum, split_every={0: 3})
 
-    # specify max_total_num_input_blocks to force d to fuse
-    opt_fn = fuse_multiple_levels(max_total_num_input_blocks=6)
+    opt_fn = fuse_multiple_levels()
 
     d.visualize(optimize_function=opt_fn)
 
@@ -1176,7 +1198,7 @@ def test_optimize_stack(spec):
     c = xp.stack((a, b), axis=0)
     d = c + 1
     # try to fuse all ops into one (d will fuse with c, but c won't fuse with a and b)
-    d.compute(optimize_function=fuse_multiple_levels(max_total_num_input_blocks=10))
+    d.compute(optimize_function=fuse_multiple_levels())
 
 
 def test_optimize_concat(spec):
@@ -1186,4 +1208,4 @@ def test_optimize_concat(spec):
     c = xp.concat((a, b), axis=0)
     d = c + 1
     # try to fuse all ops into one (d will fuse with c, but c won't fuse with a and b)
-    d.compute(optimize_function=fuse_multiple_levels(max_total_num_input_blocks=10))
+    d.compute(optimize_function=fuse_multiple_levels())
diff --git a/docs/user-guide/optimization.md b/docs/user-guide/optimization.md
index b8e3a458..ec556546 100644
--- a/docs/user-guide/optimization.md
+++ b/docs/user-guide/optimization.md
@@ -112,9 +112,9 @@ e.visualize(optimize_function=opt_fn)
 
 The `max_total_num_input_blocks` argument to `multiple_inputs_optimize_dag` specifies the maximum number of input blocks (chunks) that are allowed in the fused operation.
 
-Again, this is to limit the number of reads that an individual task must perform. The default is `None`, which means that operations are fused only if they have the same number of tasks. If set to an integer, then this limitation is removed, and tasks with a different number of tasks will be fused - as long as the total number of input blocks does not exceed the maximum. This setting is useful for reductions, and can be set using `functools.partial`:
+Again, this is to limit the number of reads that an individual task must perform. If set to `None`, operations are fused only if they have the same number of tasks. If set to an integer (the default is 10), then operations with a different number of tasks will be fused - as long as the total number of input blocks does not exceed the maximum. This setting is useful for reductions, and can be changed using `functools.partial`:
 
 ```python
-opt_fn = partial(multiple_inputs_optimize_dag, max_total_num_input_blocks=10)
+opt_fn = partial(multiple_inputs_optimize_dag, max_total_num_input_blocks=20)
 e.visualize(optimize_function=opt_fn)
 ```
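
Usage note (not part of the patch): a minimal sketch of how the new default interacts with an explicit `None` override, mirroring the merge_chunks tests above. The `Spec` value is an assumed placeholder; everything else follows the test code in this diff.

```python
from functools import partial

import cubed
import cubed.array_api as xp
from cubed.core.ops import merge_chunks
from cubed.core.optimization import multiple_inputs_optimize_dag

# assumed local spec; any Spec with enough allowed_mem works here
spec = cubed.Spec(allowed_mem="1GB")

a = xp.ones((3, 2), chunks=(1, 2), spec=spec)
b = xp.negative(a)
c = merge_chunks(b, chunks=(3, 2))  # c has fewer tasks than b

# With the new default (max_total_num_input_blocks=10) the negative and
# merge_chunks ops fuse, since each fused task reads only 3 input blocks.
c.compute(optimize_function=multiple_inputs_optimize_dag)

# Restoring the old behaviour: None means "fuse only if the ops have the
# same number of tasks", so the two ops stay separate here.
opt_fn = partial(multiple_inputs_optimize_dag, max_total_num_input_blocks=None)
c.compute(optimize_function=opt_fn)
```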