Skip to content

Commit

Permalink
Add debug log statements to optimizer to see why nodes are fused
Browse files Browse the repository at this point in the history
  • Loading branch information
tomwhite committed Feb 19, 2024
1 parent 7056a72 commit 7df4910
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 3 deletions.
17 changes: 17 additions & 0 deletions cubed/core/optimization.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import logging

import networkx as nx

from cubed.primitive.blockwise import (
Expand All @@ -7,6 +9,8 @@
fuse_multiple,
)

logger = logging.getLogger(__name__)


def simple_optimize_dag(dag):
"""Apply map blocks fusion."""
Expand Down Expand Up @@ -108,16 +112,20 @@ def can_fuse_predecessors(

# if node itself can't be fused then there is nothing to fuse
if not is_fusable(nodes[name]):
logger.debug("can't fuse %s since it is not fusable", name)
return False

# if no predecessor ops can be fused then there is nothing to fuse
if all(not is_fusable(nodes[pre]) for pre in predecessor_ops(dag, name)):
logger.debug("can't fuse %s since no predecessor ops can be fused", name)
return False

# if node is in never_fuse or always_fuse list then it overrides logic below
if never_fuse is not None and name in never_fuse:
logger.debug("can't fuse %s since it is in 'never_fuse'", name)
return False
if always_fuse is not None and name in always_fuse:
logger.debug("can fuse %s since it is in 'always_fuse'", name)
return True

# if there is more than a single predecessor op, and the total number of source arrays to
Expand All @@ -128,6 +136,12 @@ def can_fuse_predecessors(
for pre in predecessor_ops(dag, name)
)
if total_source_arrays > max_total_source_arrays:
logger.debug(
"can't fuse %s since total number of source arrays (%s) exceeds max (%s)",
name,
total_source_arrays,
max_total_source_arrays,
)
return False

predecessor_primitive_ops = [
Expand All @@ -136,6 +150,7 @@ def can_fuse_predecessors(
if is_fusable(nodes[pre])
]
return can_fuse_multiple_primitive_ops(
name,
nodes[name]["primitive_op"],
predecessor_primitive_ops,
max_total_num_input_blocks=max_total_num_input_blocks,
Expand Down Expand Up @@ -219,6 +234,8 @@ def multiple_inputs_optimize_dag(
):
"""Fuse multiple inputs."""
for name in list(nx.topological_sort(dag)):
if name.startswith("array-"):
continue
dag = fuse_predecessors(
dag,
name,
Expand Down
51 changes: 48 additions & 3 deletions cubed/primitive/blockwise.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import itertools
import logging
import math
from collections.abc import Iterator
from dataclasses import dataclass
Expand All @@ -23,6 +24,9 @@

from .types import CubedArrayProxy, MemoryModeller, PrimitiveOperation

logger = logging.getLogger(__name__)


sym_counter = 0


Expand Down Expand Up @@ -352,6 +356,7 @@ def can_fuse_primitive_ops(


def can_fuse_multiple_primitive_ops(
name: str,
primitive_op: PrimitiveOperation,
predecessor_primitive_ops: List[PrimitiveOperation],
*,
Expand All @@ -362,27 +367,67 @@ def can_fuse_multiple_primitive_ops(
):
# If the peak projected memory for running all the predecessor ops in
# order is larger than allowed_mem then we can't fuse.
if peak_projected_mem(predecessor_primitive_ops) > primitive_op.allowed_mem:
peak_projected = peak_projected_mem(predecessor_primitive_ops)
if peak_projected > primitive_op.allowed_mem:
logger.debug(
"can't fuse %s since peak projected memory for predecessor ops (%s) is greater than allowed (%s)",
name,
peak_projected,
primitive_op.allowed_mem,
)
return False
# If the number of input blocks for each input is not uniform, then we
# can't fuse. (This should never happen since all operations are
# currently uniform, and fused operations are too if fuse is applied in
# topological order.)
num_input_blocks = primitive_op.pipeline.config.num_input_blocks
if not all(num_input_blocks[0] == n for n in num_input_blocks):
logger.debug(
"can't fuse %s since number of input blocks for each input is not uniform: %s",
name,
num_input_blocks,
)
return False
if max_total_num_input_blocks is None:
# If max total input blocks not specified, then only fuse if num
# tasks of predecessor ops match.
return all(
ret = all(
primitive_op.num_tasks == p.num_tasks for p in predecessor_primitive_ops
)
if ret:
logger.debug(
"can fuse %s since num tasks of predecessor ops match", name
)
else:
logger.debug(
"can't fuse %s since num tasks of predecessor ops do not match",
name,
)
return ret
else:
total_num_input_blocks = 0
for ni, p in zip(num_input_blocks, predecessor_primitive_ops):
for nj in p.pipeline.config.num_input_blocks:
total_num_input_blocks += ni * nj
return total_num_input_blocks <= max_total_num_input_blocks
ret = total_num_input_blocks <= max_total_num_input_blocks
if ret:
logger.debug(
"can fuse %s since total number of input blocks (%s) does not exceed max (%s)",
name,
total_num_input_blocks,
max_total_num_input_blocks,
)
else:
logger.debug(
"can't fuse %s since total number of input blocks (%s) exceeds max (%s)",
name,
total_num_input_blocks,
max_total_num_input_blocks,
)
return ret
logger.debug(
"can't fuse %s since primitive op and predecessors are not all candidates", name
)
return False


Expand Down

0 comments on commit 7df4910

Please sign in to comment.