Change the DAG to have separate nodes for operations and arrays (#337)
tomwhite authored Jan 3, 2024
1 parent 49f9682 commit f83e556
Showing 4 changed files with 203 additions and 110 deletions.
253 changes: 166 additions & 87 deletions cubed/core/plan.py
@@ -5,6 +5,7 @@
 from functools import lru_cache
 
 import networkx as nx
+import zarr
 
 from cubed.primitive.blockwise import can_fuse_pipelines, fuse
 from cubed.runtime.pipeline import visit_nodes
@@ -15,6 +16,14 @@
 # A unique ID with sensible ordering, used for making directory names
 CONTEXT_ID = f"cubed-{datetime.now().strftime('%Y%m%dT%H%M%S')}-{uuid.uuid4()}"
 
+sym_counter = 0
+
+
+def gensym(name="op"):
+    global sym_counter
+    sym_counter += 1
+    return f"{name}-{sym_counter:03}"
+
 
 class Plan:
     """Deferred computation plan for a graph of arrays.
@@ -29,6 +38,12 @@ class Plan:
     a function with repeated inputs. For example, consider `equals` where the
     two arguments are the same array. We need to keep track of these cases, so
     we use a NetworkX `MultiDiGraph` rather than just as `DiGraph`.
+
+    Compared to a more traditional DAG representing a computation, in Cubed
+    nodes are not values that are passed to functions, they are instead
+    "parallel computations" which are run for their side effects. Data does
+    not flow through the graph - it is written to external storage (Zarr files)
+    as the output of one pipeline, then read back as the input to later pipelines.
     """
 
     def __init__(self, dag):
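
To make the new docstring paragraph concrete, here is a minimal sketch (not part of the commit; node names are hypothetical) of the alternating op/array structure, built with plain NetworkX:

    import networkx as nx

    dag = nx.MultiDiGraph()
    dag.add_node("op-001", type="op")      # e.g. a creation function
    dag.add_node("array-1", type="array")  # its output, backed by Zarr storage
    dag.add_edge("op-001", "array-1")

    dag.add_node("op-002", type="op")      # e.g. a blockwise operation
    dag.add_node("array-2", type="array")
    dag.add_edge("array-1", "op-002")      # op-002 reads array-1 from storage...
    dag.add_edge("op-002", "array-2")      # ...and writes array-2

    # Ops alternate with arrays; data moves via Zarr, not along the edges.
    print(list(nx.topological_sort(dag)))  # e.g. ['op-001', 'array-1', 'op-002', 'array-2']
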
@@ -55,28 +70,50 @@ def _new(
         frame = inspect.currentframe().f_back  # go back one in the stack
         stack_summaries = extract_stack_summaries(frame, limit=10)
 
+        op_name_unique = gensym()
+
         if pipeline is None:
+            # op
             dag.add_node(
-                name,
-                name=name,
+                op_name_unique,
+                name=op_name_unique,
                 op_name=op_name,
-                target=target,
+                type="op",
                 stack_summaries=stack_summaries,
                 hidden=hidden,
             )
-        else:
+            # array (when multiple outputs are supported there could be more than one)
             dag.add_node(
                 name,
                 name=name,
                 op_name=op_name,
+                type="array",
                 target=target,
+                hidden=hidden,
+            )
+            dag.add_edge(op_name_unique, name)
+        else:
+            # op
+            dag.add_node(
+                op_name_unique,
+                name=op_name_unique,
+                op_name=op_name,
+                type="op",
                 stack_summaries=stack_summaries,
                 hidden=hidden,
                 pipeline=pipeline,
             )
+            # array (when multiple outputs are supported there could be more than one)
+            dag.add_node(
+                name,
+                name=name,
+                type="array",
+                target=target,
+                hidden=hidden,
+            )
+            dag.add_edge(op_name_unique, name)
         for x in source_arrays:
             if hasattr(x, "name"):
-                dag.add_edge(x.name, name)
+                dag.add_edge(x.name, op_name_unique)
 
         return Plan(dag)

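Note the last change in this hunk: input edges now point at the op node rather than at its output array. For the `equals` case from the docstring, where the same array is passed twice, this produces two parallel edges into the op, which is exactly what the MultiDiGraph preserves. A hypothetical sketch:

    import networkx as nx

    dag = nx.MultiDiGraph()
    dag.add_node("array-1", type="array")
    dag.add_node("op-003", type="op", op_name="equals")
    for x_name in ("array-1", "array-1"):  # the same input passed twice
        dag.add_edge(x_name, "op-003")

    print(dag.number_of_edges("array-1", "op-003"))  # 2 (a plain DiGraph would keep only 1)
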
@@ -93,42 +130,62 @@ def optimize(self):
         nodes = {n: d for (n, d) in dag.nodes(data=True)}
 
         def can_fuse(n):
-            # node must have a single predecessor
-            #   - not multiple edges pointing to a single predecessor
-            # node must be the single successor to the predecessor
-            # and both must have pipelines that can be fused
-            if dag.in_degree(n) != 1:
+            # fuse a single chain looking like this:
+            # op1 -> op2_input -> op2
+
+            op2 = n
+
+            # if node (op2) does not have a pipeline then it can't be fused
+            if "pipeline" not in nodes[op2]:
                 return False
-            pre = next(dag.predecessors(n))
-            if dag.out_degree(pre) != 1:
+
+            # if node (op2) does not have exactly one input then don't fuse
+            # (it could have no inputs or multiple inputs)
+            if dag.in_degree(op2) != 1:
                 return False
-            if "pipeline" not in nodes[pre] or "pipeline" not in nodes[n]:
+
+            # if input is used by another node then don't fuse
+            op2_input = next(dag.predecessors(op2))
+            if dag.out_degree(op2_input) != 1:
                 return False
-            return can_fuse_pipelines(nodes[pre]["pipeline"], nodes[n]["pipeline"])
+
+            # if node producing input (op1) has more than one output then don't fuse
+            op1 = next(dag.predecessors(op2_input))
+            if dag.out_degree(op1) != 1:
+                return False
+
+            # op1 and op2 must have pipelines that can be fused
+            if "pipeline" not in nodes[op1]:
+                return False
+            return can_fuse_pipelines(nodes[op1]["pipeline"], nodes[op2]["pipeline"])
 
         for n in list(dag.nodes()):
             if can_fuse(n):
-                pre = next(dag.predecessors(n))
-                pipeline = fuse(nodes[pre]["pipeline"], nodes[n]["pipeline"])
-                nodes[n]["pipeline"] = pipeline
-                assert nodes[n]["target"] == pipeline.target_array
+                op2 = n
+                op2_input = next(dag.predecessors(op2))
+                op1 = next(dag.predecessors(op2_input))
+                op1_inputs = list(dag.predecessors(op1))
+
+                pipeline = fuse(nodes[op1]["pipeline"], nodes[op2]["pipeline"])
+                nodes[op2]["pipeline"] = pipeline
 
-                for p in dag.predecessors(pre):
-                    dag.add_edge(p, n)
-                dag.remove_node(pre)
+                for n in op1_inputs:
+                    dag.add_edge(n, op2)
+                dag.remove_node(op2_input)
+                dag.remove_node(op1)
 
         return Plan(dag)
 
     def _create_lazy_zarr_arrays(self, dag):
         # find all lazy zarr arrays in dag
-        all_array_nodes = []
+        all_pipeline_nodes = []
         lazy_zarr_arrays = []
         reserved_mem_values = []
         for n, d in dag.nodes(data=True):
             if "pipeline" in d and d["pipeline"].reserved_mem is not None:
                 reserved_mem_values.append(d["pipeline"].reserved_mem)
-            if isinstance(d["target"], LazyZarrArray):
-                all_array_nodes.append(n)
+                all_pipeline_nodes.append(n)
+            if "target" in d and isinstance(d["target"], LazyZarrArray):
                 lazy_zarr_arrays.append(d["target"])
 
         reserved_mem = max(reserved_mem_values, default=0)
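
The rewritten can_fuse looks for the three-node chain op1 -> op2_input -> op2 instead of a single predecessor, because an array node now sits between any two ops. A toy illustration (not from the commit) of just the structural checks, with pipeline compatibility elided:

    import networkx as nx

    dag = nx.MultiDiGraph()
    dag.add_edge("op1", "op2_input")  # op1 writes a single intermediate array
    dag.add_edge("op2_input", "op2")  # which is read only by op2

    op2 = "op2"
    op2_input = next(dag.predecessors(op2))
    op1 = next(dag.predecessors(op2_input))

    fusible = (
        dag.in_degree(op2) == 1             # op2 has exactly one input
        and dag.out_degree(op2_input) == 1  # the intermediate feeds only op2
        and dag.out_degree(op1) == 1        # op1 produces only that intermediate
    )
    print(fusible)  # True: fusing merges op1 into op2 and drops op2_input
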
@@ -142,14 +199,20 @@ def _create_lazy_zarr_arrays(self, dag):
             dag.add_node(
                 name,
                 name=name,
                 op_name=op_name,
-                target=None,
+                type="op",
                 pipeline=pipeline,
                 projected_mem=pipeline.projected_mem,
                 num_tasks=pipeline.num_tasks,
             )
-            # make create arrays node a dependency of all lazy array nodes
-            for n in all_array_nodes:
-                dag.add_edge(name, n)
+            dag.add_node(
+                "arrays",
+                name="arrays",
+                target=None,
+            )
+            dag.add_edge(name, "arrays")
+            # make create arrays node a predecessor of all pipeline nodes so it runs first
+            for n in all_pipeline_nodes:
+                dag.add_edge("arrays", n)
 
         return dag

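The wiring this hunk creates: the create-arrays op gains an explicit "arrays" output node, which is then made a predecessor of every pipeline op so that array creation runs first. A hypothetical sketch of the resulting shape (assuming the op node is named "create-arrays", as the surrounding code suggests):

    import networkx as nx

    dag = nx.MultiDiGraph()
    all_pipeline_nodes = ["op-001", "op-002"]  # ops that carry pipelines

    dag.add_node("create-arrays", type="op")
    dag.add_node("arrays", name="arrays", target=None)
    dag.add_edge("create-arrays", "arrays")
    for n in all_pipeline_nodes:
        dag.add_edge("arrays", n)

    # every pipeline op now depends on the arrays node
    print(all(dag.has_edge("arrays", n) for n in all_pipeline_nodes))  # True
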
@@ -196,7 +259,7 @@ def num_tasks(self, optimize_graph=True, resume=None):
     def num_arrays(self, optimize_graph: bool = True) -> int:
         """Return the number of arrays in this plan."""
         dag = self._finalize_dag(optimize_graph=optimize_graph)
-        return sum(n != "create-arrays" for n in dag.nodes())
+        return sum(d.get("type") == "array" for _, d in dag.nodes(data=True))
 
     def max_projected_mem(self, optimize_graph=True, resume=None):
         """Return the maximum projected memory across all tasks to execute this plan."""
@@ -213,8 +276,8 @@ def visualize(
         dag = self._finalize_dag(optimize_graph=optimize_graph)
         dag = dag.copy()  # make a copy since we mutate the DAG below
 
-        # remove edges from create-arrays node to avoid cluttering the diagram
-        dag.remove_edges_from(list(dag.out_edges("create-arrays")))
+        # remove edges from create-arrays output node to avoid cluttering the diagram
+        dag.remove_edges_from(list(dag.out_edges("arrays")))
 
         # remove hidden nodes
         dag.remove_nodes_from(
@@ -252,69 +315,85 @@ def visualize(
 
         # now set node attributes with visualization info
         for n, d in dag.nodes(data=True):
-            if d["op_name"] == "blockwise":
-                d["style"] = "filled"
-                d["fillcolor"] = "#dcbeff"
-                op_name_summary = "(bw)"
-            elif d["op_name"] == "rechunk":
-                d["style"] = "filled"
-                d["fillcolor"] = "#aaffc3"
-                op_name_summary = "(rc)"
-            else:  # creation function
-                op_name_summary = ""
-            target = d["target"]
-            if target is not None:
+            tooltip = f"name: {n}\n"
+            node_type = d.get("type", None)
+            if node_type == "op":
+                op_name = d["op_name"]
+                if op_name == "blockwise":
+                    d["style"] = '"rounded,filled"'
+                    d["fillcolor"] = "#dcbeff"
+                    op_name_summary = "(bw)"
+                elif op_name == "rechunk":
+                    d["style"] = '"rounded,filled"'
+                    d["fillcolor"] = "#aaffc3"
+                    op_name_summary = "(rc)"
+                else:
+                    # creation function
+                    d["style"] = "rounded"
+                    op_name_summary = ""
+                tooltip += f"op: {op_name}"
+
+                if "pipeline" in d:
+                    pipeline = d["pipeline"]
+                    tooltip += (
+                        f"\nprojected memory: {memory_repr(pipeline.projected_mem)}"
+                    )
+                    tooltip += f"\ntasks: {pipeline.num_tasks}"
+                    if pipeline.write_chunks is not None:
+                        tooltip += f"\nwrite chunks: {pipeline.write_chunks}"
+
+                    # remove pipeline attribute since it is a long string that causes graphviz to fail
+                    del d["pipeline"]
+
+                if "stack_summaries" in d and d["stack_summaries"] is not None:
+                    # add call stack information
+                    stack_summaries = d["stack_summaries"]
+
+                    first_cubed_i = min(
+                        i for i, s in enumerate(stack_summaries) if s.is_cubed()
+                    )
+                    first_cubed_summary = stack_summaries[first_cubed_i]
+                    caller_summary = stack_summaries[first_cubed_i - 1]
+
+                    d["label"] = f"{first_cubed_summary.name} {op_name_summary}"
+
+                    calls = " -> ".join(
+                        [
+                            s.name
+                            for s in stack_summaries
+                            if not s.is_on_python_lib_path()
+                        ]
+                    )
+
+                    line = f"{caller_summary.lineno} in {caller_summary.name}"
+
+                    tooltip += f"\ncalls: {calls}"
+                    tooltip += f"\nline: {line}"
+                    del d["stack_summaries"]
+
+            elif node_type == "array":
+                target = d["target"]
                 chunkmem = memory_repr(chunk_memory(target.dtype, target.chunks))
-                tooltip = (
-                    f"name: {n}\n"
-                    f"shape: {target.shape}\n"
-                    f"chunks: {target.chunks}\n"
-                    f"dtype: {target.dtype}\n"
-                    f"chunk memory: {chunkmem}\n"
-                )
-            else:
-                tooltip = ""
-            if "pipeline" in d:
-                pipeline = d["pipeline"]
-                tooltip += f"\nprojected memory: {memory_repr(pipeline.projected_mem)}"
-                tooltip += f"\ntasks: {pipeline.num_tasks}"
-                if pipeline.write_chunks is not None:
-                    tooltip += f"\nwrite chunks: {pipeline.write_chunks}"
-            if "stack_summaries" in d and d["stack_summaries"] is not None:
-                # add call stack information
-                stack_summaries = d["stack_summaries"]
-
-                first_cubed_i = min(
-                    i for i, s in enumerate(stack_summaries) if s.is_cubed()
-                )
-                first_cubed_summary = stack_summaries[first_cubed_i]
-                caller_summary = stack_summaries[first_cubed_i - 1]
 
+                # materialized arrays are light orange, virtual arrays are white
+                if isinstance(target, (LazyZarrArray, zarr.Array)):
+                    d["style"] = "filled"
+                    d["fillcolor"] = "#ffd8b1"
                 if n in array_display_names:
-                    var_name = f" ({array_display_names[n]})"
+                    var_name = array_display_names[n]
+                    d["label"] = f"{n} ({var_name})"
+                    tooltip += f"variable: {var_name}\n"
                 else:
-                    var_name = ""
-                d[
-                    "label"
-                ] = f"{n}{var_name}\n{first_cubed_summary.name} {op_name_summary}"
+                    d["label"] = n
+                tooltip += f"shape: {target.shape}\n"
+                tooltip += f"chunks: {target.chunks}\n"
+                tooltip += f"dtype: {target.dtype}\n"
+                tooltip += f"chunk memory: {chunkmem}\n"
 
-                calls = " -> ".join(
-                    [s.name for s in stack_summaries if not s.is_on_python_lib_path()]
-                )
-
-                line = f"{caller_summary.lineno} in {caller_summary.name}"
-
-                tooltip += f"\ncalls: {calls}"
-                tooltip += f"\nline: {line}"
-                del d["stack_summaries"]
+                del d["target"]
 
             d["tooltip"] = tooltip.strip()
 
-            # remove pipeline attribute since it is a long string that causes graphviz to fail
-            if "pipeline" in d:
-                del d["pipeline"]
-            if "target" in d:
-                del d["target"]
             if "name" in d:  # pydot already has name
                 del d["name"]
         gv = nx.drawing.nx_pydot.to_pydot(dag)
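
The loop above translates node metadata into Graphviz attributes before handing the graph to pydot. A condensed sketch of the same pattern (not from the commit; requires pydot to be installed):

    import networkx as nx

    dag = nx.MultiDiGraph()
    dag.add_node("op-001", type="op")
    dag.add_node("array-1", type="array")
    dag.add_edge("op-001", "array-1")

    for n, d in dag.nodes(data=True):
        d["tooltip"] = f"name: {n}"
        if d["type"] == "op":
            d["style"] = "rounded"
        else:  # array
            d["style"] = "filled"
            d["fillcolor"] = "#ffd8b1"  # materialized arrays are light orange

    gv = nx.drawing.nx_pydot.to_pydot(dag)
    print(gv.to_string())  # DOT source with the attributes above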