Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

More control over operation fusion #355

Merged
merged 5 commits into from
Jan 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions cubed/core/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,8 @@ def blockwise(

extra_projected_mem = kwargs.pop("extra_projected_mem", 0)

fusable = kwargs.pop("fusable", True)

name = gensym()
spec = check_array_specs(arrays)
if target_store is None:
Expand All @@ -281,6 +283,7 @@ def blockwise(
in_names=in_names,
out_name=name,
extra_func_kwargs=extra_func_kwargs,
fusable=fusable,
**kwargs,
)
plan = Plan._new(
Expand Down Expand Up @@ -688,6 +691,7 @@ def wrap(*a, block_id=None, **kw):
chunks=chunks,
extra_source_arrays=args,
extra_projected_mem=extra_projected_mem,
fusable=False, # don't allow fusion since side inputs are not accounted for
**kwargs,
)

Expand Down
61 changes: 51 additions & 10 deletions cubed/core/optimization.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,10 +92,12 @@ def predecessor_ops(dag, name):

def is_fusable(node_dict):
    """Return True if a node can be fused.

    A node is fusable when it carries a primitive op (i.e. it is an
    operation node, not an array node) AND that primitive op has not
    opted out of fusion via its ``fusable`` flag.
    """
    # NOTE: the scraped diff retained the stale pre-change return line
    # alongside the new one; only the final form is kept here.
    return "primitive_op" in node_dict and node_dict["primitive_op"].fusable


def can_fuse_predecessors(dag, name, *, max_total_nargs=4):
def can_fuse_predecessors(
dag, name, *, max_total_nargs=4, always_fuse=None, never_fuse=None
):
nodes = dict(dag.nodes(data=True))

# if node itself can't be fused then there is nothing to fuse
Expand All @@ -106,6 +108,12 @@ def can_fuse_predecessors(dag, name, *, max_total_nargs=4):
if all(not is_fusable(nodes[pre]) for pre in predecessor_ops(dag, name)):
return False

# if node is in never_fuse or always_fuse list then it overrides logic below
if never_fuse is not None and name in never_fuse:
return False
if always_fuse is not None and name in always_fuse:
return True

# if there is more than a single predecessor op, and the total number of args to
# the fused function would be more than an allowed maximum, then don't fuse
if len(list(predecessor_ops(dag, name))) > 1:
Expand All @@ -126,24 +134,32 @@ def can_fuse_predecessors(dag, name, *, max_total_nargs=4):
)


def fuse_predecessors(dag, name):
def fuse_predecessors(
dag, name, *, max_total_nargs=4, always_fuse=None, never_fuse=None
):
"""Fuse a node with its immediate predecessors."""

# if can't fuse then return dag unchanged
if not can_fuse_predecessors(dag, name):
if not can_fuse_predecessors(
dag,
name,
max_total_nargs=max_total_nargs,
always_fuse=always_fuse,
never_fuse=never_fuse,
):
return dag

nodes = dict(dag.nodes(data=True))

primitive_op = nodes[name]["primitive_op"]

# if a predecessor op has no primitive op then just use None
# if a predecessor has no primitive op then just use None
predecessor_primitive_ops = [
nodes[pre]["primitive_op"] if is_fusable(nodes[pre]) else None
for pre in predecessor_ops(dag, name)
]

# if a predecessor op has no func then use 1 for nargs
# if a predecessor has no primitive op then use 1 for nargs
predecessor_funcs_nargs = [
len(list(predecessors(dag, pre))) if is_fusable(nodes[pre]) else 1
for pre in predecessor_ops(dag, name)
Expand All @@ -167,12 +183,12 @@ def fuse_predecessors(dag, name):
for input in predecessors(dag, name):
pre = next(predecessors(dag, input))
if not is_fusable(fused_nodes[pre]):
# if a predecessor is marked as not fusable then don't change the edge
# if a predecessor is not fusable then don't change the edge
continue
fused_dag.remove_edge(input, name)
for pre in predecessor_ops(dag, name):
if not is_fusable(fused_nodes[pre]):
# if a predecessor is marked as not fusable then don't change the edge
# if a predecessor is not fusable then don't change the edge
continue
for input in predecessors(dag, pre):
fused_dag.add_edge(input, name)
Expand All @@ -188,8 +204,33 @@ def fuse_predecessors(dag, name):
return fused_dag


def multiple_inputs_optimize_dag(
    dag, *, max_total_nargs=4, always_fuse=None, never_fuse=None
):
    """Fuse multiple inputs.

    Walks the DAG in topological order and attempts to fuse each op node
    with its immediate predecessors, threading the fusion-control knobs
    through to ``fuse_predecessors``.

    Parameters
    ----------
    dag : the operation graph to optimize
    max_total_nargs : maximum number of args the fused function may take
    always_fuse : op names to fuse regardless of the usual heuristics, or None
    never_fuse : op names to never fuse, or None
    """
    # NOTE: the scraped diff interleaved the old signature/call with the
    # new ones; this is the reconstructed final form.
    for name in list(nx.topological_sort(dag)):
        dag = fuse_predecessors(
            dag,
            name,
            max_total_nargs=max_total_nargs,
            always_fuse=always_fuse,
            never_fuse=never_fuse,
        )
    return dag


def fuse_all_optimize_dag(dag):
    """Force all operations to be fused."""
    dag = dag.copy()
    # Every op node (name prefixed "op-") is put on the always-fuse list.
    op_names = [node for node in dag.nodes() if node.startswith("op-")]
    return multiple_inputs_optimize_dag(dag, always_fuse=op_names)


def fuse_only_optimize_dag(dag, *, only_fuse=None):
    """Force only specified operations to be fused; all others are left
    unfused even if they are suitable for fusion.

    Parameters
    ----------
    dag : the operation graph to optimize
    only_fuse : iterable of op names to fuse, or None for no explicit
        selection (falls back to the default optimization behavior)
    """
    dag = dag.copy()
    # Guard: the original crashed with TypeError on set(None) when the
    # default only_fuse=None was used. Treat None as "no selection".
    if only_fuse is None:
        return multiple_inputs_optimize_dag(dag)
    always_fuse = only_fuse
    # Everything not explicitly selected is forbidden from fusing.
    never_fuse = set(op for op in dag.nodes() if op.startswith("op-")) - set(only_fuse)
    return multiple_inputs_optimize_dag(
        dag, always_fuse=always_fuse, never_fuse=never_fuse
    )
1 change: 1 addition & 0 deletions cubed/core/plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -447,4 +447,5 @@ def create_zarr_arrays(lazy_zarr_arrays, reserved_mem):
projected_mem=projected_mem,
reserved_mem=reserved_mem,
num_tasks=num_tasks,
fusable=False,
)
6 changes: 6 additions & 0 deletions cubed/primitive/blockwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ def blockwise(
out_name: Optional[str] = None,
extra_projected_mem: int = 0,
extra_func_kwargs: Optional[Dict[str, Any]] = None,
fusable: bool = True,
**kwargs,
):
"""Apply a function to multiple blocks from multiple inputs, expressed using concise indexing rules.
Expand Down Expand Up @@ -196,6 +197,7 @@ def blockwise(
in_names=in_names,
extra_projected_mem=extra_projected_mem,
extra_func_kwargs=extra_func_kwargs,
fusable=fusable,
**kwargs,
)

Expand All @@ -213,6 +215,7 @@ def general_blockwise(
in_names: Optional[List[str]] = None,
extra_projected_mem: int = 0,
extra_func_kwargs: Optional[Dict[str, Any]] = None,
fusable: bool = True,
**kwargs,
):
"""A more general form of ``blockwise`` that uses a function to specify the block
Expand Down Expand Up @@ -307,6 +310,7 @@ def general_blockwise(
projected_mem=projected_mem,
reserved_mem=reserved_mem,
num_tasks=num_tasks,
fusable=fusable,
)


Expand Down Expand Up @@ -383,6 +387,7 @@ def fused_func(*args):
projected_mem=projected_mem,
reserved_mem=reserved_mem,
num_tasks=num_tasks,
fusable=True,
)


Expand Down Expand Up @@ -469,6 +474,7 @@ def fused_func(*args):
projected_mem=projected_mem,
reserved_mem=reserved_mem,
num_tasks=num_tasks,
fusable=True,
)


Expand Down
1 change: 1 addition & 0 deletions cubed/primitive/rechunk.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,5 +208,6 @@ def spec_to_primitive_op(
projected_mem=projected_mem,
reserved_mem=reserved_mem,
num_tasks=num_tasks,
fusable=False,
write_chunks=spec.write.chunks,
)
14 changes: 14 additions & 0 deletions cubed/primitive/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,25 @@ class PrimitiveOperation:
"""Encapsulates metadata about a ``blockwise`` or ``rechunk`` primitive operation."""

pipeline: CubedPipeline
"""The pipeline that runs this operation."""

target_array: Any
"""The array being computed by this operation."""

projected_mem: int
"""An upper bound of the memory needed to run a task, in bytes."""

reserved_mem: int
"""The memory reserved on a worker for non-data use when running a task, in bytes."""

num_tasks: int
"""The number of tasks needed to run this operation."""

fusable: bool = True
"""Whether this operation should be considered for fusion."""

write_chunks: Optional[T_RegularChunks] = None
"""The chunk size used by this operation."""


class CubedArrayProxy:
Expand Down
Loading
Loading