diff --git a/cubed/core/ops.py b/cubed/core/ops.py
index 536b9eea..fcb0dba0 100644
--- a/cubed/core/ops.py
+++ b/cubed/core/ops.py
@@ -262,6 +262,8 @@ def blockwise(
     extra_projected_mem = kwargs.pop("extra_projected_mem", 0)
 
+    fusable = kwargs.pop("fusable", True)
+
     name = gensym()
     spec = check_array_specs(arrays)
     if target_store is None:
@@ -281,6 +283,7 @@ def blockwise(
         in_names=in_names,
         out_name=name,
         extra_func_kwargs=extra_func_kwargs,
+        fusable=fusable,
         **kwargs,
     )
     plan = Plan._new(
@@ -688,6 +691,7 @@ def wrap(*a, block_id=None, **kw):
         chunks=chunks,
         extra_source_arrays=args,
         extra_projected_mem=extra_projected_mem,
+        fusable=False,  # don't allow fusion since side inputs are not accounted for
         **kwargs,
     )
 
diff --git a/cubed/core/optimization.py b/cubed/core/optimization.py
index 52acac4c..348f4bcf 100644
--- a/cubed/core/optimization.py
+++ b/cubed/core/optimization.py
@@ -92,10 +92,12 @@ def predecessor_ops(dag, name):
 
 def is_fusable(node_dict):
     "Return True if a node can be fused."
-    return "primitive_op" in node_dict
+    return "primitive_op" in node_dict and node_dict["primitive_op"].fusable
 
 
-def can_fuse_predecessors(dag, name, *, max_total_nargs=4):
+def can_fuse_predecessors(
+    dag, name, *, max_total_nargs=4, always_fuse=None, never_fuse=None
+):
     nodes = dict(dag.nodes(data=True))
 
     # if node itself can't be fused then there is nothing to fuse
@@ -106,6 +108,12 @@ def can_fuse_predecessors(dag, name, *, max_total_nargs=4):
     if all(not is_fusable(nodes[pre]) for pre in predecessor_ops(dag, name)):
         return False
 
+    # if the node is in the never_fuse or always_fuse lists then that overrides the logic below
+    if never_fuse is not None and name in never_fuse:
+        return False
+    if always_fuse is not None and name in always_fuse:
+        return True
+
     # if there is more than a single predecessor op, and the total number of args to
     # the fused function would be more than an allowed maximum, then don't fuse
     if len(list(predecessor_ops(dag, name))) > 1:
@@ -126,24 +134,32 @@ def can_fuse_predecessors(dag, name, *, max_total_nargs=4):
     )
 
 
-def fuse_predecessors(dag, name):
+def fuse_predecessors(
+    dag, name, *, max_total_nargs=4, always_fuse=None, never_fuse=None
+):
     """Fuse a node with its immediate predecessors."""
 
     # if can't fuse then return dag unchanged
-    if not can_fuse_predecessors(dag, name):
+    if not can_fuse_predecessors(
+        dag,
+        name,
+        max_total_nargs=max_total_nargs,
+        always_fuse=always_fuse,
+        never_fuse=never_fuse,
+    ):
         return dag
 
     nodes = dict(dag.nodes(data=True))
 
     primitive_op = nodes[name]["primitive_op"]
 
-    # if a predecessor op has no primitive op then just use None
+    # if a predecessor has no primitive op then just use None
     predecessor_primitive_ops = [
         nodes[pre]["primitive_op"] if is_fusable(nodes[pre]) else None
         for pre in predecessor_ops(dag, name)
     ]
 
-    # if a predecessor op has no func then use 1 for nargs
+    # if a predecessor has no primitive op then use 1 for nargs
     predecessor_funcs_nargs = [
         len(list(predecessors(dag, pre))) if is_fusable(nodes[pre]) else 1
         for pre in predecessor_ops(dag, name)
@@ -167,12 +183,12 @@ def fuse_predecessors(dag, name):
     for input in predecessors(dag, name):
         pre = next(predecessors(dag, input))
         if not is_fusable(fused_nodes[pre]):
-            # if a predecessor is marked as not fusable then don't change the edge
+            # if a predecessor is not fusable then don't change the edge
             continue
         fused_dag.remove_edge(input, name)
     for pre in predecessor_ops(dag, name):
         if not is_fusable(fused_nodes[pre]):
-            # if a predecessor is marked as not fusable then don't change the edge
+            # if a predecessor is not fusable then don't change the edge
             continue
         for input in predecessors(dag, pre):
             fused_dag.add_edge(input, name)
@@ -188,8 +204,33 @@ def fuse_predecessors(dag, name):
     return fused_dag
 
 
-def multiple_inputs_optimize_dag(dag):
+def multiple_inputs_optimize_dag(
+    dag, *, max_total_nargs=4, always_fuse=None, never_fuse=None
+):
     """Fuse multiple inputs."""
     for name in list(nx.topological_sort(dag)):
-        dag = fuse_predecessors(dag, name)
+        dag = fuse_predecessors(
+            dag,
+            name,
+            max_total_nargs=max_total_nargs,
+            always_fuse=always_fuse,
+            never_fuse=never_fuse,
+        )
     return dag
+
+
+def fuse_all_optimize_dag(dag):
+    """Force all operations to be fused."""
+    dag = dag.copy()
+    always_fuse = [op for op in dag.nodes() if op.startswith("op-")]
+    return multiple_inputs_optimize_dag(dag, always_fuse=always_fuse)
+
+
+def fuse_only_optimize_dag(dag, *, only_fuse=None):
+    """Force only the specified operations to be fused; all others are left unfused, even if they are suitable for fusion."""
+    dag = dag.copy()
+    always_fuse = only_fuse
+    never_fuse = set(op for op in dag.nodes() if op.startswith("op-")) - set(only_fuse)
+    return multiple_inputs_optimize_dag(
+        dag, always_fuse=always_fuse, never_fuse=never_fuse
+    )
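Note: the three optimize functions above differ only in how they seed always_fuse and
never_fuse. A minimal sketch of how a user might select between them (illustrative only,
assuming the default cubed Spec; the patterns mirror the test changes further down):

    from functools import partial

    import cubed.array_api as xp
    from cubed.core.optimization import (
        fuse_all_optimize_dag,
        fuse_only_optimize_dag,
        multiple_inputs_optimize_dag,
    )

    a = xp.ones((2, 2), chunks=(2, 2))
    b = xp.negative(a)
    c = xp.negative(b)

    # default multiple-input fusion, but with a wider fan-in limit than the default of 4
    c.compute(optimize_function=partial(multiple_inputs_optimize_dag, max_total_nargs=8))

    # force every op to be fused, ignoring max_total_nargs entirely
    c.compute(optimize_function=fuse_all_optimize_dag)

    # fuse only the op that produces c; every other op is put on the never_fuse list
    op_name = next(c.plan.dag.predecessors(c.name))
    c.compute(optimize_function=partial(fuse_only_optimize_dag, only_fuse=[op_name]))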
diff --git a/cubed/core/plan.py b/cubed/core/plan.py
index 6043371f..7360c26f 100644
--- a/cubed/core/plan.py
+++ b/cubed/core/plan.py
@@ -447,4 +447,5 @@ def create_zarr_arrays(lazy_zarr_arrays, reserved_mem):
         projected_mem=projected_mem,
         reserved_mem=reserved_mem,
         num_tasks=num_tasks,
+        fusable=False,
     )
diff --git a/cubed/primitive/blockwise.py b/cubed/primitive/blockwise.py
index d9190510..80519b4a 100644
--- a/cubed/primitive/blockwise.py
+++ b/cubed/primitive/blockwise.py
@@ -115,6 +115,7 @@ def blockwise(
     out_name: Optional[str] = None,
     extra_projected_mem: int = 0,
     extra_func_kwargs: Optional[Dict[str, Any]] = None,
+    fusable: bool = True,
     **kwargs,
 ):
     """Apply a function to multiple blocks from multiple inputs, expressed using concise indexing rules.
@@ -196,6 +197,7 @@ def blockwise(
         in_names=in_names,
         extra_projected_mem=extra_projected_mem,
         extra_func_kwargs=extra_func_kwargs,
+        fusable=fusable,
         **kwargs,
     )
 
@@ -213,6 +215,7 @@ def general_blockwise(
     in_names: Optional[List[str]] = None,
     extra_projected_mem: int = 0,
     extra_func_kwargs: Optional[Dict[str, Any]] = None,
+    fusable: bool = True,
     **kwargs,
 ):
     """A more general form of ``blockwise`` that uses a function to specify the block
@@ -307,6 +310,7 @@ def general_blockwise(
         projected_mem=projected_mem,
         reserved_mem=reserved_mem,
         num_tasks=num_tasks,
+        fusable=fusable,
     )
 
 
@@ -383,6 +387,7 @@ def fused_func(*args):
         projected_mem=projected_mem,
         reserved_mem=reserved_mem,
         num_tasks=num_tasks,
+        fusable=True,
     )
 
 
@@ -469,6 +474,7 @@ def fused_func(*args):
         projected_mem=projected_mem,
         reserved_mem=reserved_mem,
         num_tasks=num_tasks,
+        fusable=True,
     )
diff --git a/cubed/primitive/rechunk.py b/cubed/primitive/rechunk.py
index ba28e280..d1a25051 100644
--- a/cubed/primitive/rechunk.py
+++ b/cubed/primitive/rechunk.py
@@ -208,5 +208,6 @@ def spec_to_primitive_op(
         projected_mem=projected_mem,
         reserved_mem=reserved_mem,
         num_tasks=num_tasks,
+        fusable=False,
         write_chunks=spec.write.chunks,
     )
diff --git a/cubed/primitive/types.py b/cubed/primitive/types.py
index f4df42c7..835861c8 100644
--- a/cubed/primitive/types.py
+++ b/cubed/primitive/types.py
@@ -13,11 +13,25 @@ class PrimitiveOperation:
     """Encapsulates metadata about a ``blockwise`` or ``rechunk`` primitive operation."""
 
     pipeline: CubedPipeline
+    """The pipeline that runs this operation."""
+
     target_array: Any
+    """The array being computed by this operation."""
+
     projected_mem: int
+    """An upper bound of the memory needed to run a task, in bytes."""
+
     reserved_mem: int
+    """The memory reserved on a worker for non-data use when running a task, in bytes."""
+
     num_tasks: int
+    """The number of tasks needed to run this operation."""
+
+    fusable: bool = True
+    """Whether this operation should be considered for fusion."""
+
     write_chunks: Optional[T_RegularChunks] = None
+    """The chunk size used by this operation."""
 
 
 class CubedArrayProxy:
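Note: with the primitive changes above, PrimitiveOperation.fusable defaults to True for
ordinary blockwise operations and for the fused operations built by the optimizer, while
rechunk, Zarr array creation, and map_direct (whose side inputs the fusion logic cannot
see) pin it to False. A sketch of how a caller could opt a custom blockwise operation out
of fusion via the new keyword (hypothetical double function, and assuming the
func/out_ind/*args calling convention of cubed.core.ops.blockwise):

    import cubed.array_api as xp
    from cubed.core.ops import blockwise

    a = xp.ones((2, 2), chunks=(2, 2))

    def double(x):
        return x * 2

    # fusable=False is popped from kwargs by cubed.core.ops.blockwise and forwarded
    # to the primitive operation, so this op is never considered a fusion candidate
    b = blockwise(double, "ij", a, "ij", dtype=a.dtype, fusable=False)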
diff --git a/cubed/tests/test_optimization.py b/cubed/tests/test_optimization.py
index d97588b1..ed4b8699 100644
--- a/cubed/tests/test_optimization.py
+++ b/cubed/tests/test_optimization.py
@@ -9,7 +9,14 @@
 import cubed.array_api as xp
 from cubed.backend_array_api import namespace as nxp
 from cubed.core.ops import elemwise
-from cubed.core.optimization import fuse_predecessors, gensym
+from cubed.core.optimization import (
+    fuse_all_optimize_dag,
+    fuse_only_optimize_dag,
+    fuse_predecessors,
+    gensym,
+    multiple_inputs_optimize_dag,
+    simple_optimize_dag,
+)
 from cubed.core.plan import arrays_to_plan
 from cubed.tests.utils import TaskCounter
 
@@ -74,12 +81,14 @@ def test_no_fusion(spec):
     c = xp.positive(b)
     d = xp.equal(b, c)
 
+    opt_fn = simple_optimize_dag
+
     num_created_arrays = 3  # b, c, d
     assert d.plan.num_tasks(optimize_graph=False) == num_created_arrays + 3
-    assert d.plan.num_tasks(optimize_graph=True) == num_created_arrays + 3
+    assert d.plan.num_tasks(optimize_function=opt_fn) == num_created_arrays + 3
 
     task_counter = TaskCounter()
-    result = d.compute(callbacks=[task_counter])
+    result = d.compute(optimize_function=opt_fn, callbacks=[task_counter])
     assert task_counter.value == num_created_arrays + 3
 
     assert_array_equal(result, np.ones((2, 2)))
@@ -94,12 +103,14 @@ def test_no_fusion_multiple_edges(spec):
     # this should not be fused under the current logic
     d = xp.equal(b, c)
 
+    opt_fn = simple_optimize_dag
+
     num_created_arrays = 2  # c, d
     assert d.plan.num_tasks(optimize_graph=False) == num_created_arrays + 2
-    assert d.plan.num_tasks(optimize_graph=True) == num_created_arrays + 2
+    assert d.plan.num_tasks(optimize_function=opt_fn) == num_created_arrays + 2
 
     task_counter = TaskCounter()
-    result = d.compute(callbacks=[task_counter])
+    result = d.compute(optimize_function=opt_fn, callbacks=[task_counter])
     assert task_counter.value == num_created_arrays + 2
 
     assert_array_equal(result, np.full((2, 2), True))
@@ -127,11 +138,16 @@ def custom_optimize_function(dag):
     )
 
 
-def get_optimize_function(arr):
+def fuse_one_level(arr):
     # use fuse_predecessors to test one level of fusion
     return partial(fuse_predecessors, name=next(arr.plan.dag.predecessors(arr.name)))
 
 
+def fuse_multiple_levels(*, max_total_nargs=4):
+    # use multiple_inputs_optimize_dag to test multiple levels of fusion
+    return partial(multiple_inputs_optimize_dag, max_total_nargs=max_total_nargs)
+
+
 # utility functions for testing structural equivalence of dags
 
 
@@ -198,7 +214,7 @@ def test_fuse_unary_op(spec):
     b = xp.negative(a)
     c = xp.negative(b)
 
-    opt_fn = get_optimize_function(c)
+    opt_fn = fuse_one_level(c)
 
     c.visualize(optimize_function=opt_fn)
 
@@ -238,7 +254,7 @@ def test_fuse_binary_op(spec):
     d = xp.negative(b)
     e = xp.add(c, d)
 
-    opt_fn = get_optimize_function(e)
+    opt_fn = fuse_one_level(e)
 
     e.visualize(optimize_function=opt_fn)
 
@@ -279,7 +295,7 @@ def test_fuse_unary_and_binary_op(spec):
     e = xp.add(b, c)
     f = xp.add(d, e)
 
-    opt_fn = get_optimize_function(f)
+    opt_fn = fuse_one_level(f)
 
     f.visualize(optimize_function=opt_fn)
 
@@ -312,7 +328,7 @@ def test_fuse_mixed_levels(spec):
     d = xp.add(b, c)
     e = xp.add(a, d)
 
-    opt_fn = get_optimize_function(e)
+    opt_fn = fuse_one_level(e)
 
     e.visualize(optimize_function=opt_fn)
 
@@ -344,7 +360,7 @@ def test_fuse_diamond(spec):
     c = xp.positive(a)
     d = xp.add(b, c)
 
-    opt_fn = get_optimize_function(d)
+    opt_fn = fuse_one_level(d)
 
     d.visualize(optimize_function=opt_fn)
 
@@ -377,7 +393,7 @@ def test_fuse_mixed_levels_and_diamond(spec):
     c = xp.positive(b)
     d = xp.add(b, c)
 
-    opt_fn = get_optimize_function(d)
+    opt_fn = fuse_one_level(d)
 
     d.visualize(optimize_function=opt_fn)
 
@@ -408,7 +424,7 @@ def test_fuse_repeated_argument(spec):
     b = xp.negative(a)
     c = xp.add(b, b)
 
-    opt_fn = get_optimize_function(c)
+    opt_fn = fuse_one_level(c)
 
     c.visualize(optimize_function=opt_fn)
 
@@ -439,7 +455,7 @@ def test_fuse_other_dependents(spec):
     d = xp.negative(b)
 
     # only fuse c; leave d unfused
-    opt_fn = get_optimize_function(c)
+    opt_fn = fuse_one_level(c)
 
     # note multi-arg forms of visualize and compute below
     cubed.visualize(c, d, optimize_function=opt_fn)
@@ -460,19 +476,84 @@ def test_fuse_other_dependents(spec):
     assert_array_equal(d_result, np.ones((2, 2)))
 
 
-# large fan-in
+# unary large fan-in
 #
-#  a b c d e f g h    ->    a b c d e f g h
-#   \ / \ / \ / \ /          \ / \ / \ / \ /
-#    i   j   k   m            i   j   k   m
-#     \ /     \ /              \  \   /  /
-#      n       o                \  \  /  /
-#       \     /                 ----- p -----
+#  a b c d e f g h    ->    a b c d e f g h
+#   \ \ \ \ / / / /          \ \ \ \ / / / /
+#    \ \ \ \ / / /            \ \ \ \ / / /
+#     \ \ \ \ / /              \ \ \ \ / /
+#      \ \ \ \ /                \ \ \ \ /
+#      ----- i -----            ----- j -----
+#           |
+#           j
+#
+def test_fuse_unary_large_fan_in(spec):
+    a = xp.ones((2, 2), chunks=(2, 2), spec=spec)
+    b = xp.ones((2, 2), chunks=(2, 2), spec=spec)
+    c = xp.ones((2, 2), chunks=(2, 2), spec=spec)
+    d = xp.ones((2, 2), chunks=(2, 2), spec=spec)
+    e = xp.ones((2, 2), chunks=(2, 2), spec=spec)
+    f = xp.ones((2, 2), chunks=(2, 2), spec=spec)
+    g = xp.ones((2, 2), chunks=(2, 2), spec=spec)
+    h = xp.ones((2, 2), chunks=(2, 2), spec=spec)
+
+    # use elemwise and stack since add can only take 2 args
+    def stack_add(*a):
+        return nxp.sum(nxp.stack(a), axis=0)
+
+    i = elemwise(stack_add, a, b, c, d, e, f, g, h, dtype=a.dtype)
+    j = xp.negative(i)
+
+    # max_total_nargs is left at its default (4), which does not limit fusion since j is unary
+    opt_fn = fuse_one_level(j)
+
+    j.visualize(optimize_function=opt_fn)
+
+    # check structure of optimized dag
+    expected_fused_dag = create_dag()
+    add_placeholder_op(expected_fused_dag, (), (a,))
+    add_placeholder_op(expected_fused_dag, (), (b,))
+    add_placeholder_op(expected_fused_dag, (), (c,))
+    add_placeholder_op(expected_fused_dag, (), (d,))
+    add_placeholder_op(expected_fused_dag, (), (e,))
+    add_placeholder_op(expected_fused_dag, (), (f,))
+    add_placeholder_op(expected_fused_dag, (), (g,))
+    add_placeholder_op(expected_fused_dag, (), (h,))
+    add_placeholder_op(
+        expected_fused_dag,
+        (
+            a,
+            b,
+            c,
+            d,
+            e,
+            f,
+            g,
+            h,
+        ),
+        (j,),
+    )
+    assert structurally_equivalent(
+        j.plan.optimize(optimize_function=opt_fn).dag, expected_fused_dag
+    )
+
+    result = j.compute(optimize_function=opt_fn)
+    assert_array_equal(result, -8 * np.ones((2, 2)))
+
+
+# large fan-in default
+#
+#  a b c d e f g h    ->    a b c d e f g h
+#   \ / \ / \ / \ /          \ \ / /  \ \ / /
+#    i   j   k   m           -- n --  -- o --
+#     \ /     \ /                \      /
+#      n       o                  \    /
+#       \     /                   -- p --
 #        \   /
 #         \ /
 #          p
 #
-def test_fuse_large_fan_in(spec):
+def test_fuse_large_fan_in_default(spec):
     a = xp.ones((2, 2), chunks=(2, 2), spec=spec)
     b = xp.ones((2, 2), chunks=(2, 2), spec=spec)
     c = xp.ones((2, 2), chunks=(2, 2), spec=spec)
@@ -492,7 +573,8 @@ def test_fuse_large_fan_in_default(spec):
 
     p = xp.add(n, o)
 
-    opt_fn = get_optimize_function(p)
+    # max_total_nargs is left at its default (4), so only one level is fused
+    opt_fn = fuse_multiple_levels()
 
     p.visualize(optimize_function=opt_fn)
 
@@ -506,20 +588,9 @@ def test_fuse_large_fan_in_default(spec):
     add_placeholder_op(expected_fused_dag, (), (f,))
     add_placeholder_op(expected_fused_dag, (), (g,))
     add_placeholder_op(expected_fused_dag, (), (h,))
-    add_placeholder_op(expected_fused_dag, (a, b), (i,))
-    add_placeholder_op(expected_fused_dag, (c, d), (j,))
-    add_placeholder_op(expected_fused_dag, (e, f), (k,))
-    add_placeholder_op(expected_fused_dag, (g, h), (m,))
-    add_placeholder_op(
-        expected_fused_dag,
-        (
-            i,
-            j,
-            k,
-            m,
-        ),
-        (p,),
-    )
+    add_placeholder_op(expected_fused_dag, (a, b, c, d), (n,))
+    add_placeholder_op(expected_fused_dag, (e, f, g, h), (o,))
+    add_placeholder_op(expected_fused_dag, (n, o), (p,))
     assert structurally_equivalent(
         p.plan.optimize(optimize_function=opt_fn).dag, expected_fused_dag
     )
@@ -528,18 +599,19 @@ def test_fuse_large_fan_in_default(spec):
     assert_array_equal(result, 8 * np.ones((2, 2)))
 
 
-# unary large fan-in
+# large fan-in override
 #
-#  a b c d e f g h    ->    a b c d e f g h
-#   \ \ \ \ / / / /          \ \ \ \ / / / /
-#    \ \ \ \ / / /            \ \ \ \ / / /
-#     \ \ \ \ / /              \ \ \ \ / /
-#      \ \ \ \ /                \ \ \ \ /
-#      ----- i -----            ----- j -----
-#           |
-#           j
+#  a b c d e f g h    ->    a b c d e f g h
+#   \ / \ / \ / \ /          \ \ \ \ / / / /
+#    i   j   k   m            \ \ \ \ / / /
+#     \ /     \ /              \ \ \ \ / /
+#      n       o                \ \ \ \ /
+#       \     /                 ----- p -----
+#        \   /
+#         \ /
+#          p
 #
-def test_fuse_unary_large_fan_in(spec):
+def test_fuse_large_fan_in_override(spec):
     a = xp.ones((2, 2), chunks=(2, 2), spec=spec)
     b = xp.ones((2, 2), chunks=(2, 2), spec=spec)
     c = xp.ones((2, 2), chunks=(2, 2), spec=spec)
@@ -549,16 +621,20 @@ def test_fuse_large_fan_in_override(spec):
     g = xp.ones((2, 2), chunks=(2, 2), spec=spec)
     h = xp.ones((2, 2), chunks=(2, 2), spec=spec)
 
-    # use elemwise and stack since add can only take 2 args
-    def stack_add(*a):
-        return nxp.sum(nxp.stack(a), axis=0)
+    i = xp.add(a, b)
+    j = xp.add(c, d)
+    k = xp.add(e, f)
+    m = xp.add(g, h)
 
-    i = elemwise(stack_add, a, b, c, d, e, f, g, h, dtype=a.dtype)
-    j = xp.negative(i)
+    n = xp.add(i, j)
+    o = xp.add(k, m)
+
+    p = xp.add(n, o)
 
-    opt_fn = get_optimize_function(j)
+    # max_total_nargs is overridden so multiple levels are fused
+    opt_fn = fuse_multiple_levels(max_total_nargs=8)
 
-    j.visualize(optimize_function=opt_fn)
+    p.visualize(optimize_function=opt_fn)
 
     # check structure of optimized dag
     expected_fused_dag = create_dag()
@@ -582,11 +658,49 @@ def stack_add(*a):
             g,
             h,
         ),
-        (j,),
+        (p,),
     )
     assert structurally_equivalent(
-        j.plan.optimize(optimize_function=opt_fn).dag, expected_fused_dag
+        p.plan.optimize(optimize_function=opt_fn).dag, expected_fused_dag
     )
 
-    result = j.compute(optimize_function=opt_fn)
-    assert_array_equal(result, -8 * np.ones((2, 2)))
+    result = p.compute(optimize_function=opt_fn)
+    assert_array_equal(result, 8 * np.ones((2, 2)))
+
+    # now force everything to be fused with fuse_all_optimize_dag
+    # note that max_total_nargs is *not* set
+    opt_fn = fuse_all_optimize_dag
+
+    assert structurally_equivalent(
+        p.plan.optimize(optimize_function=opt_fn).dag, expected_fused_dag
+    )
+
+    result = p.compute(optimize_function=opt_fn)
+    assert_array_equal(result, 8 * np.ones((2, 2)))
+
+
+def test_fuse_only_optimize_dag(spec):
+    a = xp.ones((2, 2), chunks=(2, 2), spec=spec)
+    b = xp.negative(a)
+    c = xp.negative(b)
+    d = xp.negative(c)
+
+    # only fuse d (with c)
+    # b should remain unfused, even though it is fusable
+    op_name = next(d.plan.dag.predecessors(d.name))
+    opt_fn = partial(fuse_only_optimize_dag, only_fuse=[op_name])
+
+    d.visualize(optimize_function=opt_fn)
+
+    # check structure of optimized dag
+    expected_fused_dag = create_dag()
+    add_placeholder_op(expected_fused_dag, (), (a,))
+    add_placeholder_op(expected_fused_dag, (a,), (b,))
+    add_placeholder_op(expected_fused_dag, (b,), (d,))
+    assert structurally_equivalent(
+        d.plan.optimize(optimize_function=opt_fn).dag,
+        expected_fused_dag,
+    )
+
+    result = d.compute(optimize_function=opt_fn)
+    assert_array_equal(result, -np.ones((2, 2)))
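Note: because is_fusable consults PrimitiveOperation.fusable, operations that pin it to
False stay separate even under the forcing functions tested above. A sketch (assuming
Array.rechunk, which builds its primitive op with fusable=False via spec_to_primitive_op):

    import cubed.array_api as xp
    from cubed.core.optimization import fuse_all_optimize_dag

    a = xp.ones((4, 4), chunks=(2, 2))
    b = xp.negative(a)
    c = b.rechunk((4, 4))  # rechunk primitive: fusable=False
    d = xp.negative(c)

    # even with fuse_all_optimize_dag the rechunk remains a fusion barrier: d's only
    # predecessor op is not fusable, so can_fuse_predecessors returns False before
    # the always_fuse override is even consulted
    d.visualize(optimize_function=fuse_all_optimize_dag)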