Skip to content

Commit

Permalink
Show both operation ID and name (e.g. "op-005 add") in plan visualiza…
Browse files Browse the repository at this point in the history
…tion and progress bars. (#396)

Change plan visualization to always show array or op ID at top.
  • Loading branch information
tomwhite authored Feb 27, 2024
1 parent 0ad9ea3 commit 91c5cd1
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 11 deletions.
18 changes: 10 additions & 8 deletions cubed/core/plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import zarr

from cubed.core.optimization import simple_optimize_dag
from cubed.primitive.blockwise import BlockwiseSpec
from cubed.primitive.types import PrimitiveOperation
from cubed.runtime.pipeline import visit_nodes
from cubed.runtime.types import ComputeEndEvent, ComputeStartEvent, CubedPipeline
Expand Down Expand Up @@ -72,6 +73,9 @@ def _new(
frame = inspect.currentframe().f_back # go back one in the stack
stack_summaries = extract_stack_summaries(frame, limit=10)

first_cubed_i = min(i for i, s in enumerate(stack_summaries) if s.is_cubed())
first_cubed_summary = stack_summaries[first_cubed_i]

op_name_unique = gensym()

if primitive_op is None:
Expand All @@ -82,6 +86,7 @@ def _new(
op_name=op_name,
type="op",
stack_summaries=stack_summaries,
op_display_name=f"{op_name_unique}\n{first_cubed_summary.name}",
hidden=hidden,
)
# array (when multiple outputs are supported there could be more than one)
Expand All @@ -101,6 +106,7 @@ def _new(
op_name=op_name,
type="op",
stack_summaries=stack_summaries,
op_display_name=f"{op_name_unique}\n{first_cubed_summary.name}",
hidden=hidden,
primitive_op=primitive_op,
pipeline=primitive_op.pipeline,
Expand Down Expand Up @@ -160,6 +166,7 @@ def _create_lazy_zarr_arrays(self, dag):
name=name,
op_name=op_name,
type="op",
op_display_name=name,
primitive_op=primitive_op,
pipeline=primitive_op.pipeline,
)
Expand Down Expand Up @@ -307,19 +314,17 @@ def visualize(
tooltip = f"name: {n}\n"
node_type = d.get("type", None)
if node_type == "op":
label = d["op_display_name"]
op_name = d["op_name"]
if op_name == "blockwise":
d["style"] = '"rounded,filled"'
d["fillcolor"] = "#dcbeff"
op_name_summary = "(bw)"
elif op_name == "rechunk":
d["style"] = '"rounded,filled"'
d["fillcolor"] = "#aaffc3"
op_name_summary = "(rc)"
else:
# creation function
d["style"] = "rounded"
op_name_summary = ""
tooltip += f"op: {op_name}"

num_tasks = None
Expand All @@ -337,7 +342,7 @@ def visualize(
# remove pipeline attribute since it is a long string that causes graphviz to fail
if "pipeline" in d:
pipeline = d["pipeline"]
if pipeline.config is not None:
if isinstance(pipeline.config, BlockwiseSpec):
tooltip += (
f"\nnum input blocks: {pipeline.config.num_input_blocks}"
)
Expand All @@ -350,11 +355,8 @@ def visualize(
first_cubed_i = min(
i for i, s in enumerate(stack_summaries) if s.is_cubed()
)
first_cubed_summary = stack_summaries[first_cubed_i]
caller_summary = stack_summaries[first_cubed_i - 1]

label = f"{first_cubed_summary.name} {op_name_summary}"

calls = " -> ".join(
[
s.name
Expand Down Expand Up @@ -384,7 +386,7 @@ def visualize(
nbytes = memory_repr(target.nbytes)
if n in array_display_names:
var_name = array_display_names[n]
label = f"{n} ({var_name})"
label = f"{n}\n{var_name}"
tooltip += f"variable: {var_name}\n"
tooltip += f"shape: {target.shape}\n"
tooltip += f"chunks: {target.chunks}\n"
Expand Down
5 changes: 4 additions & 1 deletion cubed/extensions/rich.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,10 @@ def on_compute_start(self, event):
progress_tasks = {}
for name, node in visit_nodes(event.dag, event.resume):
num_tasks = node["primitive_op"].num_tasks
progress_task = progress.add_task(f"{name}", start=False, total=num_tasks)
op_display_name = node["op_display_name"].replace("\n", " ")
progress_task = progress.add_task(
f"{op_display_name}", start=False, total=num_tasks
)
progress_tasks[name] = progress_task

self.logger_aware_progress = logger_aware_progress
Expand Down
7 changes: 6 additions & 1 deletion cubed/extensions/tqdm.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,13 @@ def on_compute_start(self, event):
i = 0
for name, node in visit_nodes(event.dag, event.resume):
num_tasks = node["primitive_op"].num_tasks
op_display_name = node["op_display_name"].replace("\n", " ")
self.pbars[name] = tqdm(
*self.args, desc=name, total=num_tasks, position=i, **self.kwargs
*self.args,
desc=op_display_name,
total=num_tasks,
position=i,
**self.kwargs,
)
i = i + 1

Expand Down
2 changes: 1 addition & 1 deletion cubed/tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -457,7 +457,7 @@ def test_visualize(tmp_path):
a = xp.asarray([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=xp.float64, chunks=(2, 2))
b = cubed.random.random((3, 3), chunks=(2, 2))
c = xp.add(a, b)
d = c * 2
d = c.rechunk((3, 1))
e = c * 3

f = xp.asarray([[1, 2, 3], [4, 5, 6], [7, 8, 9]], chunks=(2, 2))
Expand Down

0 comments on commit 91c5cd1

Please sign in to comment.