Skip to content

Commit

Permalink
Refactor pipeline to reduce coupling between runtime and primitive layers (#352)
Browse files Browse the repository at this point in the history

Introduce PrimitiveOperation that encapsulates information for blockwise and rechunk that the runtime layer does not need
  • Loading branch information
tomwhite authored Jan 19, 2024
1 parent 49df637 commit 30df284
Show file tree
Hide file tree
Showing 11 changed files with 226 additions and 182 deletions.
44 changes: 22 additions & 22 deletions cubed/core/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,7 @@ def blockwise(
spec = check_array_specs(arrays)
if target_store is None:
target_store = new_temp_path(name=name, spec=spec)
pipeline = primitive_blockwise(
op = primitive_blockwise(
func,
out_ind,
*zargs,
Expand All @@ -286,14 +286,14 @@ def blockwise(
plan = Plan._new(
name,
"blockwise",
pipeline.target_array,
pipeline,
op.target_array,
op,
False,
*source_arrays,
)
from cubed.array_api import Array

return Array(name, pipeline.target_array, spec, plan)
return Array(name, op.target_array, spec, plan)


def general_blockwise(
Expand Down Expand Up @@ -322,7 +322,7 @@ def general_blockwise(
spec = check_array_specs(arrays)
if target_store is None:
target_store = new_temp_path(name=name, spec=spec)
pipeline = primitive_general_blockwise(
op = primitive_general_blockwise(
func,
block_function,
*zargs,
Expand All @@ -340,14 +340,14 @@ def general_blockwise(
plan = Plan._new(
name,
"blockwise",
pipeline.target_array,
pipeline,
op.target_array,
op,
False,
*source_arrays,
)
from cubed.array_api import Array

return Array(name, pipeline.target_array, spec, plan)
return Array(name, op.target_array, spec, plan)


def elemwise(func, *args: "Array", dtype=None) -> "Array":
Expand Down Expand Up @@ -706,7 +706,7 @@ def rechunk(x, chunks, target_store=None):
target_store = new_temp_path(name=name, spec=spec)
name_int = f"{name}-int"
temp_store = new_temp_path(name=name_int, spec=spec)
pipelines = primitive_rechunk(
ops = primitive_rechunk(
x.zarray_maybe_lazy,
target_chunks=target_chunks,
allowed_mem=spec.allowed_mem,
Expand All @@ -717,40 +717,40 @@ def rechunk(x, chunks, target_store=None):

from cubed.array_api import Array

if len(pipelines) == 1:
pipeline = pipelines[0]
if len(ops) == 1:
op = ops[0]
plan = Plan._new(
name,
"rechunk",
pipeline.target_array,
pipeline,
op.target_array,
op,
False,
x,
)
return Array(name, pipeline.target_array, spec, plan)
return Array(name, op.target_array, spec, plan)

else:
pipeline1 = pipelines[0]
op1 = ops[0]
plan1 = Plan._new(
name_int,
"rechunk",
pipeline1.target_array,
pipeline1,
op1.target_array,
op1,
False,
x,
)
x_int = Array(name_int, pipeline1.target_array, spec, plan1)
x_int = Array(name_int, op1.target_array, spec, plan1)

pipeline2 = pipelines[1]
op2 = ops[1]
plan2 = Plan._new(
name,
"rechunk",
pipeline2.target_array,
pipeline2,
op2.target_array,
op2,
False,
x_int,
)
return Array(name, pipeline2.target_array, spec, plan2)
return Array(name, op2.target_array, spec, plan2)


def merge_chunks(x, chunks):
Expand Down
48 changes: 27 additions & 21 deletions cubed/core/optimization.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import networkx as nx

from cubed.primitive.blockwise import (
can_fuse_multiple_pipelines,
can_fuse_pipelines,
can_fuse_multiple_primitive_ops,
can_fuse_primitive_ops,
fuse,
fuse_multiple,
)
Expand All @@ -23,8 +23,8 @@ def can_fuse(n):

op2 = n

# if node (op2) does not have a pipeline then it can't be fused
if "pipeline" not in nodes[op2]:
# if node (op2) does not have a primitive op then it can't be fused
if "primitive_op" not in nodes[op2]:
return False

# if node (op2) does not have exactly one input then don't fuse
Expand All @@ -42,10 +42,12 @@ def can_fuse(n):
if dag.out_degree(op1) != 1:
return False

# op1 and op2 must have pipelines that can be fused
if "pipeline" not in nodes[op1]:
# op1 and op2 must have primitive ops that can be fused
if "primitive_op" not in nodes[op1]:
return False
return can_fuse_pipelines(nodes[op1]["pipeline"], nodes[op2]["pipeline"])
return can_fuse_primitive_ops(
nodes[op1]["primitive_op"], nodes[op2]["primitive_op"]
)

for n in list(dag.nodes()):
if can_fuse(n):
Expand All @@ -54,8 +56,9 @@ def can_fuse(n):
op1 = next(dag.predecessors(op2_input))
op1_inputs = list(dag.predecessors(op1))

pipeline = fuse(nodes[op1]["pipeline"], nodes[op2]["pipeline"])
nodes[op2]["pipeline"] = pipeline
primitive_op = fuse(nodes[op1]["primitive_op"], nodes[op2]["primitive_op"])
nodes[op2]["primitive_op"] = primitive_op
nodes[op2]["pipeline"] = primitive_op.pipeline

for n in op1_inputs:
dag.add_edge(n, op2)
Expand Down Expand Up @@ -89,7 +92,7 @@ def predecessor_ops(dag, name):

def is_fusable(node_dict):
"Return True if a node can be fused."
return "pipeline" in node_dict
return "primitive_op" in node_dict


def can_fuse_predecessors(dag, name, *, max_total_nargs=4):
Expand All @@ -113,12 +116,14 @@ def can_fuse_predecessors(dag, name, *, max_total_nargs=4):
if total_nargs > max_total_nargs:
return False

predecessor_pipelines = [
nodes[pre]["pipeline"]
predecessor_primitive_ops = [
nodes[pre]["primitive_op"]
for pre in predecessor_ops(dag, name)
if is_fusable(nodes[pre])
]
return can_fuse_multiple_pipelines(nodes[name]["pipeline"], *predecessor_pipelines)
return can_fuse_multiple_primitive_ops(
nodes[name]["primitive_op"], *predecessor_primitive_ops
)


def fuse_predecessors(dag, name):
Expand All @@ -130,11 +135,11 @@ def fuse_predecessors(dag, name):

nodes = dict(dag.nodes(data=True))

pipeline = nodes[name]["pipeline"]
primitive_op = nodes[name]["primitive_op"]

# if a predecessor op has no pipeline then just use None
predecessor_pipelines = [
nodes[pre]["pipeline"] if is_fusable(nodes[pre]) else None
# if a predecessor op has no primitive op then just use None
predecessor_primitive_ops = [
nodes[pre]["primitive_op"] if is_fusable(nodes[pre]) else None
for pre in predecessor_ops(dag, name)
]

Expand All @@ -144,16 +149,17 @@ def fuse_predecessors(dag, name):
for pre in predecessor_ops(dag, name)
]

fused_pipeline = fuse_multiple(
pipeline,
*predecessor_pipelines,
fused_primitive_op = fuse_multiple(
primitive_op,
*predecessor_primitive_ops,
predecessor_funcs_nargs=predecessor_funcs_nargs,
)

fused_dag = dag.copy()
fused_nodes = dict(fused_dag.nodes(data=True))

fused_nodes[name]["pipeline"] = fused_pipeline
fused_nodes[name]["primitive_op"] = fused_primitive_op
fused_nodes[name]["pipeline"] = fused_primitive_op.pipeline

# re-wire dag to remove predecessor nodes that have been fused

Expand Down
54 changes: 29 additions & 25 deletions cubed/core/plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import zarr

from cubed.core.optimization import simple_optimize_dag
from cubed.primitive.types import PrimitiveOperation
from cubed.runtime.pipeline import visit_nodes
from cubed.runtime.types import CubedPipeline
from cubed.storage.zarr import LazyZarrArray
Expand Down Expand Up @@ -57,7 +58,7 @@ def _new(
name,
op_name,
target,
pipeline=None,
primitive_op=None,
hidden=False,
*source_arrays,
):
Expand All @@ -73,7 +74,7 @@ def _new(

op_name_unique = gensym()

if pipeline is None:
if primitive_op is None:
# op
dag.add_node(
op_name_unique,
Expand Down Expand Up @@ -101,7 +102,8 @@ def _new(
type="op",
stack_summaries=stack_summaries,
hidden=hidden,
pipeline=pipeline,
primitive_op=primitive_op,
pipeline=primitive_op.pipeline,
)
# array (when multiple outputs are supported there could be more than one)
dag.add_node(
Expand Down Expand Up @@ -137,8 +139,8 @@ def _create_lazy_zarr_arrays(self, dag):
lazy_zarr_arrays = []
reserved_mem_values = []
for n, d in dag.nodes(data=True):
if "pipeline" in d and d["pipeline"].reserved_mem is not None:
reserved_mem_values.append(d["pipeline"].reserved_mem)
if "primitive_op" in d and d["primitive_op"].reserved_mem is not None:
reserved_mem_values.append(d["primitive_op"].reserved_mem)
all_pipeline_nodes.append(n)
if "target" in d and isinstance(d["target"], LazyZarrArray):
lazy_zarr_arrays.append(d["target"])
Expand All @@ -149,15 +151,14 @@ def _create_lazy_zarr_arrays(self, dag):
# add new node and edges
name = "create-arrays"
op_name = name
pipeline = create_zarr_arrays(lazy_zarr_arrays, reserved_mem)
primitive_op = create_zarr_arrays(lazy_zarr_arrays, reserved_mem)
dag.add_node(
name,
name=name,
op_name=op_name,
type="op",
pipeline=pipeline,
projected_mem=pipeline.projected_mem,
num_tasks=pipeline.num_tasks,
primitive_op=primitive_op,
pipeline=primitive_op.pipeline,
)
dag.add_node(
"arrays",
Expand Down Expand Up @@ -212,8 +213,7 @@ def num_tasks(self, optimize_graph=True, optimize_function=None, resume=None):
dag = self._finalize_dag(optimize_graph, optimize_function)
tasks = 0
for _, node in visit_nodes(dag, resume=resume):
pipeline = node["pipeline"]
tasks += pipeline.num_tasks
tasks += node["primitive_op"].num_tasks
return tasks

def num_arrays(self, optimize_graph: bool = True, optimize_function=None) -> int:
Expand All @@ -227,7 +227,7 @@ def max_projected_mem(
"""Return the maximum projected memory across all tasks to execute this plan."""
dag = self._finalize_dag(optimize_graph, optimize_function)
projected_mem_values = [
node["pipeline"].projected_mem
node["primitive_op"].projected_mem
for _, node in visit_nodes(dag, resume=resume)
]
return max(projected_mem_values) if len(projected_mem_values) > 0 else 0
Expand Down Expand Up @@ -314,16 +314,18 @@ def visualize(
op_name_summary = ""
tooltip += f"op: {op_name}"

if "pipeline" in d:
pipeline = d["pipeline"]
if "primitive_op" in d:
primitive_op = d["primitive_op"]
tooltip += (
f"\nprojected memory: {memory_repr(pipeline.projected_mem)}"
f"\nprojected memory: {memory_repr(primitive_op.projected_mem)}"
)
tooltip += f"\ntasks: {pipeline.num_tasks}"
if pipeline.write_chunks is not None:
tooltip += f"\nwrite chunks: {pipeline.write_chunks}"
tooltip += f"\ntasks: {primitive_op.num_tasks}"
if primitive_op.write_chunks is not None:
tooltip += f"\nwrite chunks: {primitive_op.write_chunks}"
del d["primitive_op"]

# remove pipeline attribute since it is a long string that causes graphviz to fail
# remove pipeline attribute since it is a long string that causes graphviz to fail
if "pipeline" in d:
del d["pipeline"]

if "stack_summaries" in d and d["stack_summaries"] is not None:
Expand Down Expand Up @@ -433,14 +435,16 @@ def create_zarr_arrays(lazy_zarr_arrays, reserved_mem):
)
num_tasks = len(lazy_zarr_arrays)

return CubedPipeline(
pipeline = CubedPipeline(
create_zarr_array,
"create_zarr_array",
lazy_zarr_arrays,
None,
None,
projected_mem,
reserved_mem,
num_tasks,
None,
)
return PrimitiveOperation(
pipeline=pipeline,
target_array=None,
projected_mem=projected_mem,
reserved_mem=reserved_mem,
num_tasks=num_tasks,
)
8 changes: 4 additions & 4 deletions cubed/extensions/history.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,14 @@ class HistoryCallback(Callback):
def on_compute_start(self, dag, resume):
plan = []
for name, node in visit_nodes(dag, resume):
pipeline = node["pipeline"]
primitive_op = node["primitive_op"]
plan.append(
dict(
array_name=name,
op_name=node["op_name"],
projected_mem=pipeline.projected_mem,
reserved_mem=pipeline.reserved_mem,
num_tasks=pipeline.num_tasks,
projected_mem=primitive_op.projected_mem,
reserved_mem=primitive_op.reserved_mem,
num_tasks=primitive_op.num_tasks,
)
)

Expand Down
2 changes: 1 addition & 1 deletion cubed/extensions/tqdm.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def on_compute_start(self, dag, resume):
self.pbars = {}
i = 0
for name, node in visit_nodes(dag, resume):
num_tasks = node["pipeline"].num_tasks
num_tasks = node["primitive_op"].num_tasks
self.pbars[name] = tqdm(
*self.args, desc=name, total=num_tasks, position=i, **self.kwargs
)
Expand Down
Loading

0 comments on commit 30df284

Please sign in to comment.