diff --git a/cubed/core/optimization.py b/cubed/core/optimization.py index fc060ce5..268faa78 100644 --- a/cubed/core/optimization.py +++ b/cubed/core/optimization.py @@ -109,6 +109,19 @@ def is_fusable(node_dict): return is_primitive_op(node_dict) and node_dict["primitive_op"].fusable +def num_source_arrays(dag, name): + """Return the number of (non-hidden) arrays that are inputs to an op. + + Hidden arrays are used for internal bookkeeping, are very small virtual arrays + (empty, or offsets for example), and are not shown on the plan visualization. + For these reasons they shouldn't count towards ``max_total_source_arrays``. + """ + nodes = dict(dag.nodes(data=True)) + return sum( + not nodes[array]["hidden"] for array in predecessors_unordered(dag, name) + ) + + def can_fuse_predecessors( dag, name, @@ -145,9 +158,7 @@ def can_fuse_predecessors( # the fused function would be more than an allowed maximum, then don't fuse if len(list(predecessor_ops(dag, name))) > 1: total_source_arrays = sum( - len(list(predecessors_unordered(dag, pre))) - if is_primitive_op(nodes[pre]) - else 1 + num_source_arrays(dag, pre) if is_primitive_op(nodes[pre]) else 1 for pre in predecessor_ops(dag, name) ) if total_source_arrays > max_total_source_arrays: