diff --git a/cubed/core/optimization.py b/cubed/core/optimization.py index fe532562..507f896f 100644 --- a/cubed/core/optimization.py +++ b/cubed/core/optimization.py @@ -96,7 +96,13 @@ def is_fusable(node_dict): def can_fuse_predecessors( - dag, name, *, max_total_source_arrays=4, always_fuse=None, never_fuse=None + dag, + name, + *, + max_total_source_arrays=4, + max_total_num_input_blocks=None, + always_fuse=None, + never_fuse=None, ): nodes = dict(dag.nodes(data=True)) @@ -130,12 +136,20 @@ def can_fuse_predecessors( if is_fusable(nodes[pre]) ] return can_fuse_multiple_primitive_ops( - nodes[name]["primitive_op"], *predecessor_primitive_ops + nodes[name]["primitive_op"], + predecessor_primitive_ops, + max_total_num_input_blocks=max_total_num_input_blocks, ) def fuse_predecessors( - dag, name, *, max_total_source_arrays=4, always_fuse=None, never_fuse=None + dag, + name, + *, + max_total_source_arrays=4, + max_total_num_input_blocks=None, + always_fuse=None, + never_fuse=None, ): """Fuse a node with its immediate predecessors.""" @@ -144,6 +158,7 @@ def fuse_predecessors( dag, name, max_total_source_arrays=max_total_source_arrays, + max_total_num_input_blocks=max_total_num_input_blocks, always_fuse=always_fuse, never_fuse=never_fuse, ): @@ -195,7 +210,12 @@ def fuse_predecessors( def multiple_inputs_optimize_dag( - dag, *, max_total_source_arrays=4, always_fuse=None, never_fuse=None + dag, + *, + max_total_source_arrays=4, + max_total_num_input_blocks=None, + always_fuse=None, + never_fuse=None, ): """Fuse multiple inputs.""" for name in list(nx.topological_sort(dag)): @@ -203,6 +223,7 @@ def multiple_inputs_optimize_dag( dag, name, max_total_source_arrays=max_total_source_arrays, + max_total_num_input_blocks=max_total_num_input_blocks, always_fuse=always_fuse, never_fuse=never_fuse, ) diff --git a/cubed/primitive/blockwise.py b/cubed/primitive/blockwise.py index 14238278..72f6137b 100644 --- a/cubed/primitive/blockwise.py +++ b/cubed/primitive/blockwise.py @@ -352,7 +352,10 @@ def can_fuse_primitive_ops( def can_fuse_multiple_primitive_ops( - primitive_op: PrimitiveOperation, *predecessor_primitive_ops: PrimitiveOperation + primitive_op: PrimitiveOperation, + predecessor_primitive_ops: List[PrimitiveOperation], + *, + max_total_num_input_blocks: Optional[int] = None, ) -> bool: if is_fuse_candidate(primitive_op) and all( is_fuse_candidate(p) for p in predecessor_primitive_ops @@ -368,9 +371,18 @@ def can_fuse_multiple_primitive_ops( num_input_blocks = primitive_op.pipeline.config.num_input_blocks if not all(num_input_blocks[0] == n for n in num_input_blocks): return False - return all( - primitive_op.num_tasks == p.num_tasks for p in predecessor_primitive_ops - ) + if max_total_num_input_blocks is None: + # If max total input blocks not specified, then only fuse if num + # tasks of predecessor ops match. + return all( + primitive_op.num_tasks == p.num_tasks for p in predecessor_primitive_ops + ) + else: + total_num_input_blocks = 0 + for ni, p in zip(num_input_blocks, predecessor_primitive_ops): + for nj in p.pipeline.config.num_input_blocks: + total_num_input_blocks += ni * nj + return total_num_input_blocks <= max_total_num_input_blocks return False diff --git a/cubed/tests/test_optimization.py b/cubed/tests/test_optimization.py index 4cf69b01..8e069456 100644 --- a/cubed/tests/test_optimization.py +++ b/cubed/tests/test_optimization.py @@ -152,10 +152,12 @@ def fuse_one_level(arr, *, always_fuse=None): ) -def fuse_multiple_levels(*, max_total_source_arrays=4): +def fuse_multiple_levels(*, max_total_source_arrays=4, max_total_num_input_blocks=None): # use multiple_inputs_optimize_dag to test multiple levels of fusion return partial( - multiple_inputs_optimize_dag, max_total_source_arrays=max_total_source_arrays + multiple_inputs_optimize_dag, + max_total_source_arrays=max_total_source_arrays, + max_total_num_input_blocks=max_total_num_input_blocks, ) @@ -775,9 +777,8 @@ def test_fuse_merge_chunks_unary(spec): b = xp.negative(a) c = merge_chunks_new(b, chunks=(3, 2)) - # force c to fuse - last_op = sorted(c.plan.dag.nodes())[-1] - opt_fn = fuse_one_level(c, always_fuse=[last_op]) + # specify max_total_num_input_blocks to force c to fuse + opt_fn = fuse_multiple_levels(max_total_num_input_blocks=3) c.visualize(optimize_function=opt_fn) @@ -809,9 +810,8 @@ def test_fuse_merge_chunks_binary(spec): c = xp.add(a, b) d = merge_chunks_new(c, chunks=(3, 2)) - # force d to fuse - last_op = sorted(d.plan.dag.nodes())[-1] - opt_fn = fuse_one_level(d, always_fuse=[last_op]) + # specify max_total_num_input_blocks to force d to fuse + opt_fn = fuse_multiple_levels(max_total_num_input_blocks=6) d.visualize(optimize_function=opt_fn)