diff --git a/cubed/core/optimization.py b/cubed/core/optimization.py index 507f896f..a032854f 100644 --- a/cubed/core/optimization.py +++ b/cubed/core/optimization.py @@ -1,3 +1,5 @@ +import logging + import networkx as nx from cubed.primitive.blockwise import ( @@ -7,6 +9,8 @@ fuse_multiple, ) +logger = logging.getLogger(__name__) + def simple_optimize_dag(dag): """Apply map blocks fusion.""" @@ -108,16 +112,20 @@ def can_fuse_predecessors( # if node itself can't be fused then there is nothing to fuse if not is_fusable(nodes[name]): + logger.debug("can't fuse %s since it is not fusable", name) return False # if no predecessor ops can be fused then there is nothing to fuse if all(not is_fusable(nodes[pre]) for pre in predecessor_ops(dag, name)): + logger.debug("can't fuse %s since no predecessor ops can be fused", name) return False # if node is in never_fuse or always_fuse list then it overrides logic below if never_fuse is not None and name in never_fuse: + logger.debug("can't fuse %s since it is in 'never_fuse'", name) return False if always_fuse is not None and name in always_fuse: + logger.debug("can fuse %s since it is in 'always_fuse'", name) return True # if there is more than a single predecessor op, and the total number of source arrays to @@ -128,6 +136,12 @@ def can_fuse_predecessors( for pre in predecessor_ops(dag, name) ) if total_source_arrays > max_total_source_arrays: + logger.debug( + "can't fuse %s since total number of source arrays (%s) exceeds max (%s)", + name, + total_source_arrays, + max_total_source_arrays, + ) return False predecessor_primitive_ops = [ @@ -136,6 +150,7 @@ def can_fuse_predecessors( if is_fusable(nodes[pre]) ] return can_fuse_multiple_primitive_ops( + name, nodes[name]["primitive_op"], predecessor_primitive_ops, max_total_num_input_blocks=max_total_num_input_blocks, @@ -219,6 +234,8 @@ def multiple_inputs_optimize_dag( ): """Fuse multiple inputs.""" for name in list(nx.topological_sort(dag)): + if name.startswith("array-"): + continue dag = fuse_predecessors( dag, name, diff --git a/cubed/primitive/blockwise.py b/cubed/primitive/blockwise.py index 71d69dc0..d1486249 100644 --- a/cubed/primitive/blockwise.py +++ b/cubed/primitive/blockwise.py @@ -1,4 +1,5 @@ import itertools +import logging import math from collections.abc import Iterator from dataclasses import dataclass @@ -23,6 +24,9 @@ from .types import CubedArrayProxy, MemoryModeller, PrimitiveOperation +logger = logging.getLogger(__name__) + + sym_counter = 0 @@ -352,6 +356,7 @@ def can_fuse_primitive_ops( def can_fuse_multiple_primitive_ops( + name: str, primitive_op: PrimitiveOperation, predecessor_primitive_ops: List[PrimitiveOperation], *, @@ -362,7 +367,14 @@ def can_fuse_multiple_primitive_ops( ): # If the peak projected memory for running all the predecessor ops in # order is larger than allowed_mem then we can't fuse. - if peak_projected_mem(predecessor_primitive_ops) > primitive_op.allowed_mem: + peak_projected = peak_projected_mem(predecessor_primitive_ops) + if peak_projected > primitive_op.allowed_mem: + logger.debug( + "can't fuse %s since peak projected memory for predecessor ops (%s) is greater than allowed (%s)", + name, + peak_projected, + primitive_op.allowed_mem, + ) return False # If the number of input blocks for each input is not uniform, then we # can't fuse. (This should never happen since all operations are @@ -370,19 +382,52 @@ def can_fuse_multiple_primitive_ops( # topological order.) num_input_blocks = primitive_op.pipeline.config.num_input_blocks if not all(num_input_blocks[0] == n for n in num_input_blocks): + logger.debug( + "can't fuse %s since number of input blocks for each input is not uniform: %s", + name, + num_input_blocks, + ) return False if max_total_num_input_blocks is None: # If max total input blocks not specified, then only fuse if num # tasks of predecessor ops match. - return all( + ret = all( primitive_op.num_tasks == p.num_tasks for p in predecessor_primitive_ops ) + if ret: + logger.debug( + "can fuse %s since num tasks of predecessor ops match", name + ) + else: + logger.debug( + "can't fuse %s since num tasks of predecessor ops do not match", + name, + ) + return ret else: total_num_input_blocks = 0 for ni, p in zip(num_input_blocks, predecessor_primitive_ops): for nj in p.pipeline.config.num_input_blocks: total_num_input_blocks += ni * nj - return total_num_input_blocks <= max_total_num_input_blocks + ret = total_num_input_blocks <= max_total_num_input_blocks + if ret: + logger.debug( + "can fuse %s since total number of input blocks (%s) does not exceed max (%s)", + name, + total_num_input_blocks, + max_total_num_input_blocks, + ) + else: + logger.debug( + "can't fuse %s since total number of input blocks (%s) exceeds max (%s)", + name, + total_num_input_blocks, + max_total_num_input_blocks, + ) + return ret + logger.debug( + "can't fuse %s since primitive op and predecessors are not all candidates", name + ) return False