Skip to content

Commit

Permalink
Refactor pipeline to reduce coupling between runtime and primitive layers
Browse files Browse the repository at this point in the history

Introduce PrimitiveOperation, which encapsulates the information for blockwise and rechunk operations that the runtime layer does not need
  • Loading branch information
tomwhite committed Jan 18, 2024
1 parent 670ea0c commit a334cbe
Show file tree
Hide file tree
Showing 11 changed files with 226 additions and 182 deletions.
44 changes: 22 additions & 22 deletions cubed/core/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,7 @@ def blockwise(
spec = check_array_specs(arrays)
if target_store is None:
target_store = new_temp_path(name=name, spec=spec)
pipeline = primitive_blockwise(
op = primitive_blockwise(
func,
out_ind,
*zargs,
Expand All @@ -286,14 +286,14 @@ def blockwise(
plan = Plan._new(
name,
"blockwise",
pipeline.target_array,
pipeline,
op.target_array,
op,
False,
*source_arrays,
)
from cubed.array_api import Array

return Array(name, pipeline.target_array, spec, plan)
return Array(name, op.target_array, spec, plan)


def general_blockwise(
Expand Down Expand Up @@ -322,7 +322,7 @@ def general_blockwise(
spec = check_array_specs(arrays)
if target_store is None:
target_store = new_temp_path(name=name, spec=spec)
pipeline = primitive_general_blockwise(
op = primitive_general_blockwise(
func,
block_function,
*zargs,
Expand All @@ -340,14 +340,14 @@ def general_blockwise(
plan = Plan._new(
name,
"blockwise",
pipeline.target_array,
pipeline,
op.target_array,
op,
False,
*source_arrays,
)
from cubed.array_api import Array

return Array(name, pipeline.target_array, spec, plan)
return Array(name, op.target_array, spec, plan)


def elemwise(func, *args: "Array", dtype=None) -> "Array":
Expand Down Expand Up @@ -706,7 +706,7 @@ def rechunk(x, chunks, target_store=None):
target_store = new_temp_path(name=name, spec=spec)
name_int = f"{name}-int"
temp_store = new_temp_path(name=name_int, spec=spec)
pipelines = primitive_rechunk(
ops = primitive_rechunk(
x.zarray_maybe_lazy,
target_chunks=target_chunks,
allowed_mem=spec.allowed_mem,
Expand All @@ -717,40 +717,40 @@ def rechunk(x, chunks, target_store=None):

from cubed.array_api import Array

if len(pipelines) == 1:
pipeline = pipelines[0]
if len(ops) == 1:
op = ops[0]
plan = Plan._new(
name,
"rechunk",
pipeline.target_array,
pipeline,
op.target_array,
op,
False,
x,
)
return Array(name, pipeline.target_array, spec, plan)
return Array(name, op.target_array, spec, plan)

else:
pipeline1 = pipelines[0]
op1 = ops[0]
plan1 = Plan._new(
name_int,
"rechunk",
pipeline1.target_array,
pipeline1,
op1.target_array,
op1,
False,
x,
)
x_int = Array(name_int, pipeline1.target_array, spec, plan1)
x_int = Array(name_int, op1.target_array, spec, plan1)

pipeline2 = pipelines[1]
op2 = ops[1]
plan2 = Plan._new(
name,
"rechunk",
pipeline2.target_array,
pipeline2,
op2.target_array,
op2,
False,
x_int,
)
return Array(name, pipeline2.target_array, spec, plan2)
return Array(name, op2.target_array, spec, plan2)


def merge_chunks(x, chunks):
Expand Down
48 changes: 27 additions & 21 deletions cubed/core/optimization.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import networkx as nx

from cubed.primitive.blockwise import (
can_fuse_multiple_pipelines,
can_fuse_pipelines,
can_fuse_multiple_primitive_ops,
can_fuse_primitive_ops,
fuse,
fuse_multiple,
)
Expand All @@ -23,8 +23,8 @@ def can_fuse(n):

op2 = n

# if node (op2) does not have a pipeline then it can't be fused
if "pipeline" not in nodes[op2]:
# if node (op2) does not have a primitive op then it can't be fused
if "primitive_op" not in nodes[op2]:
return False

# if node (op2) does not have exactly one input then don't fuse
Expand All @@ -42,10 +42,12 @@ def can_fuse(n):
if dag.out_degree(op1) != 1:
return False

# op1 and op2 must have pipelines that can be fused
if "pipeline" not in nodes[op1]:
# op1 and op2 must have primitive ops that can be fused
if "primitive_op" not in nodes[op1]:
return False
return can_fuse_pipelines(nodes[op1]["pipeline"], nodes[op2]["pipeline"])
return can_fuse_primitive_ops(
nodes[op1]["primitive_op"], nodes[op2]["primitive_op"]
)

for n in list(dag.nodes()):
if can_fuse(n):
Expand All @@ -54,8 +56,9 @@ def can_fuse(n):
op1 = next(dag.predecessors(op2_input))
op1_inputs = list(dag.predecessors(op1))

pipeline = fuse(nodes[op1]["pipeline"], nodes[op2]["pipeline"])
nodes[op2]["pipeline"] = pipeline
primitive_op = fuse(nodes[op1]["primitive_op"], nodes[op2]["primitive_op"])
nodes[op2]["primitive_op"] = primitive_op
nodes[op2]["pipeline"] = primitive_op.pipeline

for n in op1_inputs:
dag.add_edge(n, op2)
Expand Down Expand Up @@ -89,7 +92,7 @@ def predecessor_ops(dag, name):

def is_fusable(node_dict):
"Return True if a node can be fused."
return "pipeline" in node_dict
return "primitive_op" in node_dict


def can_fuse_predecessors(dag, name, *, max_total_nargs=4):
Expand All @@ -113,12 +116,14 @@ def can_fuse_predecessors(dag, name, *, max_total_nargs=4):
if total_nargs > max_total_nargs:
return False

predecessor_pipelines = [
nodes[pre]["pipeline"]
predecessor_primitive_ops = [
nodes[pre]["primitive_op"]
for pre in predecessor_ops(dag, name)
if is_fusable(nodes[pre])
]
return can_fuse_multiple_pipelines(nodes[name]["pipeline"], *predecessor_pipelines)
return can_fuse_multiple_primitive_ops(
nodes[name]["primitive_op"], *predecessor_primitive_ops
)


def fuse_predecessors(dag, name):
Expand All @@ -130,11 +135,11 @@ def fuse_predecessors(dag, name):

nodes = dict(dag.nodes(data=True))

pipeline = nodes[name]["pipeline"]
primitive_op = nodes[name]["primitive_op"]

# if a predecessor op has no pipeline then just use None
predecessor_pipelines = [
nodes[pre]["pipeline"] if is_fusable(nodes[pre]) else None
# if a predecessor op has no primitive op then just use None
predecessor_primitive_ops = [
nodes[pre]["primitive_op"] if is_fusable(nodes[pre]) else None
for pre in predecessor_ops(dag, name)
]

Expand All @@ -144,16 +149,17 @@ def fuse_predecessors(dag, name):
for pre in predecessor_ops(dag, name)
]

fused_pipeline = fuse_multiple(
pipeline,
*predecessor_pipelines,
fused_primitive_op = fuse_multiple(
primitive_op,
*predecessor_primitive_ops,
predecessor_funcs_nargs=predecessor_funcs_nargs,
)

fused_dag = dag.copy()
fused_nodes = dict(fused_dag.nodes(data=True))

fused_nodes[name]["pipeline"] = fused_pipeline
fused_nodes[name]["primitive_op"] = fused_primitive_op
fused_nodes[name]["pipeline"] = fused_primitive_op.pipeline

# re-wire dag to remove predecessor nodes that have been fused

Expand Down
54 changes: 29 additions & 25 deletions cubed/core/plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import zarr

from cubed.core.optimization import simple_optimize_dag
from cubed.primitive.types import PrimitiveOperation
from cubed.runtime.pipeline import visit_nodes
from cubed.runtime.types import CubedPipeline
from cubed.storage.zarr import LazyZarrArray
Expand Down Expand Up @@ -57,7 +58,7 @@ def _new(
name,
op_name,
target,
pipeline=None,
primitive_op=None,
hidden=False,
*source_arrays,
):
Expand All @@ -73,7 +74,7 @@ def _new(

op_name_unique = gensym()

if pipeline is None:
if primitive_op is None:
# op
dag.add_node(
op_name_unique,
Expand Down Expand Up @@ -101,7 +102,8 @@ def _new(
type="op",
stack_summaries=stack_summaries,
hidden=hidden,
pipeline=pipeline,
primitive_op=primitive_op,
pipeline=primitive_op.pipeline,
)
# array (when multiple outputs are supported there could be more than one)
dag.add_node(
Expand Down Expand Up @@ -137,8 +139,8 @@ def _create_lazy_zarr_arrays(self, dag):
lazy_zarr_arrays = []
reserved_mem_values = []
for n, d in dag.nodes(data=True):
if "pipeline" in d and d["pipeline"].reserved_mem is not None:
reserved_mem_values.append(d["pipeline"].reserved_mem)
if "primitive_op" in d and d["primitive_op"].reserved_mem is not None:
reserved_mem_values.append(d["primitive_op"].reserved_mem)
all_pipeline_nodes.append(n)
if "target" in d and isinstance(d["target"], LazyZarrArray):
lazy_zarr_arrays.append(d["target"])
Expand All @@ -149,15 +151,14 @@ def _create_lazy_zarr_arrays(self, dag):
# add new node and edges
name = "create-arrays"
op_name = name
pipeline = create_zarr_arrays(lazy_zarr_arrays, reserved_mem)
primitive_op = create_zarr_arrays(lazy_zarr_arrays, reserved_mem)
dag.add_node(
name,
name=name,
op_name=op_name,
type="op",
pipeline=pipeline,
projected_mem=pipeline.projected_mem,
num_tasks=pipeline.num_tasks,
primitive_op=primitive_op,
pipeline=primitive_op.pipeline,
)
dag.add_node(
"arrays",
Expand Down Expand Up @@ -212,8 +213,7 @@ def num_tasks(self, optimize_graph=True, optimize_function=None, resume=None):
dag = self._finalize_dag(optimize_graph, optimize_function)
tasks = 0
for _, node in visit_nodes(dag, resume=resume):
pipeline = node["pipeline"]
tasks += pipeline.num_tasks
tasks += node["primitive_op"].num_tasks
return tasks

def num_arrays(self, optimize_graph: bool = True, optimize_function=None) -> int:
Expand All @@ -227,7 +227,7 @@ def max_projected_mem(
"""Return the maximum projected memory across all tasks to execute this plan."""
dag = self._finalize_dag(optimize_graph, optimize_function)
projected_mem_values = [
node["pipeline"].projected_mem
node["primitive_op"].projected_mem
for _, node in visit_nodes(dag, resume=resume)
]
return max(projected_mem_values) if len(projected_mem_values) > 0 else 0
Expand Down Expand Up @@ -314,16 +314,18 @@ def visualize(
op_name_summary = ""
tooltip += f"op: {op_name}"

if "pipeline" in d:
pipeline = d["pipeline"]
if "primitive_op" in d:
primitive_op = d["primitive_op"]
tooltip += (
f"\nprojected memory: {memory_repr(pipeline.projected_mem)}"
f"\nprojected memory: {memory_repr(primitive_op.projected_mem)}"
)
tooltip += f"\ntasks: {pipeline.num_tasks}"
if pipeline.write_chunks is not None:
tooltip += f"\nwrite chunks: {pipeline.write_chunks}"
tooltip += f"\ntasks: {primitive_op.num_tasks}"
if primitive_op.write_chunks is not None:
tooltip += f"\nwrite chunks: {primitive_op.write_chunks}"
del d["primitive_op"]

# remove pipeline attribute since it is a long string that causes graphviz to fail
# remove pipeline attribute since it is a long string that causes graphviz to fail
if "pipeline" in d:
del d["pipeline"]

if "stack_summaries" in d and d["stack_summaries"] is not None:
Expand Down Expand Up @@ -433,14 +435,16 @@ def create_zarr_arrays(lazy_zarr_arrays, reserved_mem):
)
num_tasks = len(lazy_zarr_arrays)

return CubedPipeline(
pipeline = CubedPipeline(
create_zarr_array,
"create_zarr_array",
lazy_zarr_arrays,
None,
None,
projected_mem,
reserved_mem,
num_tasks,
None,
)
return PrimitiveOperation(
pipeline=pipeline,
target_array=None,
projected_mem=projected_mem,
reserved_mem=reserved_mem,
num_tasks=num_tasks,
)
8 changes: 4 additions & 4 deletions cubed/extensions/history.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,14 @@ class HistoryCallback(Callback):
def on_compute_start(self, dag, resume):
plan = []
for name, node in visit_nodes(dag, resume):
pipeline = node["pipeline"]
primitive_op = node["primitive_op"]
plan.append(
dict(
array_name=name,
op_name=node["op_name"],
projected_mem=pipeline.projected_mem,
reserved_mem=pipeline.reserved_mem,
num_tasks=pipeline.num_tasks,
projected_mem=primitive_op.projected_mem,
reserved_mem=primitive_op.reserved_mem,
num_tasks=primitive_op.num_tasks,
)
)

Expand Down
2 changes: 1 addition & 1 deletion cubed/extensions/tqdm.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def on_compute_start(self, dag, resume):
self.pbars = {}
i = 0
for name, node in visit_nodes(dag, resume):
num_tasks = node["pipeline"].num_tasks
num_tasks = node["primitive_op"].num_tasks
self.pbars[name] = tqdm(
*self.args, desc=name, total=num_tasks, position=i, **self.kwargs
)
Expand Down
Loading

0 comments on commit a334cbe

Please sign in to comment.