Skip to content

Commit

Permalink
More control over operation fusion (#355)
Browse files Browse the repository at this point in the history
* Document PrimitiveOperation

* Allow a primitive operation to be marked as 'fusable'

* Add mechanism to manually fuse operations

* Add fuse_all_optimize_dag and fuse_only_optimize_dag

* Explicitly set simple_optimize_dag for 'no fusion' tests
  • Loading branch information
tomwhite authored Jan 22, 2024
1 parent 1794d87 commit 929f52c
Show file tree
Hide file tree
Showing 7 changed files with 249 additions and 68 deletions.
4 changes: 4 additions & 0 deletions cubed/core/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,8 @@ def blockwise(

extra_projected_mem = kwargs.pop("extra_projected_mem", 0)

fusable = kwargs.pop("fusable", True)

name = gensym()
spec = check_array_specs(arrays)
if target_store is None:
Expand All @@ -281,6 +283,7 @@ def blockwise(
in_names=in_names,
out_name=name,
extra_func_kwargs=extra_func_kwargs,
fusable=fusable,
**kwargs,
)
plan = Plan._new(
Expand Down Expand Up @@ -688,6 +691,7 @@ def wrap(*a, block_id=None, **kw):
chunks=chunks,
extra_source_arrays=args,
extra_projected_mem=extra_projected_mem,
fusable=False, # don't allow fusion since side inputs are not accounted for
**kwargs,
)

Expand Down
61 changes: 51 additions & 10 deletions cubed/core/optimization.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,10 +92,12 @@ def predecessor_ops(dag, name):

def is_fusable(node_dict):
"Return True if a node can be fused."
return "primitive_op" in node_dict
return "primitive_op" in node_dict and node_dict["primitive_op"].fusable


def can_fuse_predecessors(dag, name, *, max_total_nargs=4):
def can_fuse_predecessors(
dag, name, *, max_total_nargs=4, always_fuse=None, never_fuse=None
):
nodes = dict(dag.nodes(data=True))

# if node itself can't be fused then there is nothing to fuse
Expand All @@ -106,6 +108,12 @@ def can_fuse_predecessors(dag, name, *, max_total_nargs=4):
if all(not is_fusable(nodes[pre]) for pre in predecessor_ops(dag, name)):
return False

# if node is in never_fuse or always_fuse list then it overrides logic below
if never_fuse is not None and name in never_fuse:
return False
if always_fuse is not None and name in always_fuse:
return True

# if there is more than a single predecessor op, and the total number of args to
# the fused function would be more than an allowed maximum, then don't fuse
if len(list(predecessor_ops(dag, name))) > 1:
Expand All @@ -126,24 +134,32 @@ def can_fuse_predecessors(dag, name, *, max_total_nargs=4):
)


def fuse_predecessors(dag, name):
def fuse_predecessors(
dag, name, *, max_total_nargs=4, always_fuse=None, never_fuse=None
):
"""Fuse a node with its immediate predecessors."""

# if can't fuse then return dag unchanged
if not can_fuse_predecessors(dag, name):
if not can_fuse_predecessors(
dag,
name,
max_total_nargs=max_total_nargs,
always_fuse=always_fuse,
never_fuse=never_fuse,
):
return dag

nodes = dict(dag.nodes(data=True))

primitive_op = nodes[name]["primitive_op"]

# if a predecessor op has no primitive op then just use None
# if a predecessor has no primitive op then just use None
predecessor_primitive_ops = [
nodes[pre]["primitive_op"] if is_fusable(nodes[pre]) else None
for pre in predecessor_ops(dag, name)
]

# if a predecessor op has no func then use 1 for nargs
# if a predecessor has no primitive op then use 1 for nargs
predecessor_funcs_nargs = [
len(list(predecessors(dag, pre))) if is_fusable(nodes[pre]) else 1
for pre in predecessor_ops(dag, name)
Expand All @@ -167,12 +183,12 @@ def fuse_predecessors(dag, name):
for input in predecessors(dag, name):
pre = next(predecessors(dag, input))
if not is_fusable(fused_nodes[pre]):
# if a predecessor is marked as not fusable then don't change the edge
# if a predecessor is not fusable then don't change the edge
continue
fused_dag.remove_edge(input, name)
for pre in predecessor_ops(dag, name):
if not is_fusable(fused_nodes[pre]):
# if a predecessor is marked as not fusable then don't change the edge
# if a predecessor is not fusable then don't change the edge
continue
for input in predecessors(dag, pre):
fused_dag.add_edge(input, name)
Expand All @@ -188,8 +204,33 @@ def fuse_predecessors(dag, name):
return fused_dag


def multiple_inputs_optimize_dag(dag):
def multiple_inputs_optimize_dag(
dag, *, max_total_nargs=4, always_fuse=None, never_fuse=None
):
"""Fuse multiple inputs."""
for name in list(nx.topological_sort(dag)):
dag = fuse_predecessors(dag, name)
dag = fuse_predecessors(
dag,
name,
max_total_nargs=max_total_nargs,
always_fuse=always_fuse,
never_fuse=never_fuse,
)
return dag


def fuse_all_optimize_dag(dag):
"""Force all operations to be fused."""
dag = dag.copy()
always_fuse = [op for op in dag.nodes() if op.startswith("op-")]
return multiple_inputs_optimize_dag(dag, always_fuse=always_fuse)


def fuse_only_optimize_dag(dag, *, only_fuse=None):
"""Force only specified operations to be fused, all others will be left even if they are suitable for fusion."""
dag = dag.copy()
always_fuse = only_fuse
never_fuse = set(op for op in dag.nodes() if op.startswith("op-")) - set(only_fuse)
return multiple_inputs_optimize_dag(
dag, always_fuse=always_fuse, never_fuse=never_fuse
)
1 change: 1 addition & 0 deletions cubed/core/plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -447,4 +447,5 @@ def create_zarr_arrays(lazy_zarr_arrays, reserved_mem):
projected_mem=projected_mem,
reserved_mem=reserved_mem,
num_tasks=num_tasks,
fusable=False,
)
6 changes: 6 additions & 0 deletions cubed/primitive/blockwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ def blockwise(
out_name: Optional[str] = None,
extra_projected_mem: int = 0,
extra_func_kwargs: Optional[Dict[str, Any]] = None,
fusable: bool = True,
**kwargs,
):
"""Apply a function to multiple blocks from multiple inputs, expressed using concise indexing rules.
Expand Down Expand Up @@ -196,6 +197,7 @@ def blockwise(
in_names=in_names,
extra_projected_mem=extra_projected_mem,
extra_func_kwargs=extra_func_kwargs,
fusable=fusable,
**kwargs,
)

Expand All @@ -213,6 +215,7 @@ def general_blockwise(
in_names: Optional[List[str]] = None,
extra_projected_mem: int = 0,
extra_func_kwargs: Optional[Dict[str, Any]] = None,
fusable: bool = True,
**kwargs,
):
"""A more general form of ``blockwise`` that uses a function to specify the block
Expand Down Expand Up @@ -307,6 +310,7 @@ def general_blockwise(
projected_mem=projected_mem,
reserved_mem=reserved_mem,
num_tasks=num_tasks,
fusable=fusable,
)


Expand Down Expand Up @@ -383,6 +387,7 @@ def fused_func(*args):
projected_mem=projected_mem,
reserved_mem=reserved_mem,
num_tasks=num_tasks,
fusable=True,
)


Expand Down Expand Up @@ -469,6 +474,7 @@ def fused_func(*args):
projected_mem=projected_mem,
reserved_mem=reserved_mem,
num_tasks=num_tasks,
fusable=True,
)


Expand Down
1 change: 1 addition & 0 deletions cubed/primitive/rechunk.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,5 +208,6 @@ def spec_to_primitive_op(
projected_mem=projected_mem,
reserved_mem=reserved_mem,
num_tasks=num_tasks,
fusable=False,
write_chunks=spec.write.chunks,
)
14 changes: 14 additions & 0 deletions cubed/primitive/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,25 @@ class PrimitiveOperation:
"""Encapsulates metadata about a ``blockwise`` or ``rechunk`` primitive operation."""

pipeline: CubedPipeline
"""The pipeline that runs this operation."""

target_array: Any
"""The array being computed by this operation."""

projected_mem: int
"""An upper bound of the memory needed to run a task, in bytes."""

reserved_mem: int
"""The memory reserved on a worker for non-data use when running a task, in bytes."""

num_tasks: int
"""The number of tasks needed to run this operation."""

fusable: bool = True
"""Whether this operation should be considered for fusion."""

write_chunks: Optional[T_RegularChunks] = None
"""The chunk size used by this operation."""


class CubedArrayProxy:
Expand Down
Loading

0 comments on commit 929f52c

Please sign in to comment.