Skip to content

Commit

Permalink
Change total_nbytes to total_nbytes_written
Browse files Browse the repository at this point in the history
This is to avoid reporting the size of Zarr arrays being read from
(using `from_zarr`) in the intermediate `total_nbytes`.
  • Loading branch information
tomwhite committed Mar 19, 2024
1 parent e185fca commit 594e137
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 6 deletions.
10 changes: 6 additions & 4 deletions cubed/core/plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,14 +245,16 @@ def max_projected_mem(
]
return max(projected_mem_values) if len(projected_mem_values) > 0 else 0

def total_nbytes(self, optimize_graph: bool = True, optimize_function=None) -> int:
"""Return the total number of bytes for all materialized arrays in this plan."""
def total_nbytes_written(
self, optimize_graph: bool = True, optimize_function=None
) -> int:
"""Return the total number of bytes written for all materialized arrays in this plan."""
dag = self._finalize_dag(optimize_graph, optimize_function)
nbytes = 0
for _, d in dag.nodes(data=True):
if d.get("type") == "array":
target = d["target"]
if isinstance(target, (LazyZarrArray, zarr.Array)):
if isinstance(target, LazyZarrArray):
nbytes += target.nbytes
return nbytes

Expand Down Expand Up @@ -282,7 +284,7 @@ def visualize(
# note that \l is used to left-justify each line (see https://www.graphviz.org/docs/attrs/nojustify/)
rf"num tasks: {self.num_tasks(optimize_graph, optimize_function)}\l"
rf"max projected memory: {memory_repr(self.max_projected_mem(optimize_graph, optimize_function))}\l"
rf"total nbytes: {memory_repr(self.total_nbytes(optimize_graph, optimize_function))}\l"
rf"total nbytes written: {memory_repr(self.total_nbytes_written(optimize_graph, optimize_function))}\l"
rf"optimized: {optimize_graph}\l"
),
"labelloc": "bottom",
Expand Down
7 changes: 5 additions & 2 deletions cubed/tests/test_optimization.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,12 +37,15 @@ def test_fusion(spec):
num_created_arrays = 3 # b, c, d (a is not created on disk)
assert d.plan.num_arrays(optimize_graph=False) == num_arrays
assert d.plan.num_tasks(optimize_graph=False) == num_created_arrays + 12
assert d.plan.total_nbytes(optimize_graph=False) == b.nbytes + c.nbytes + d.nbytes
assert (
d.plan.total_nbytes_written(optimize_graph=False)
== b.nbytes + c.nbytes + d.nbytes
)
num_arrays = 2 # a, d
num_created_arrays = 1 # d (a is not created on disk)
assert d.plan.num_arrays(optimize_graph=True) == num_arrays
assert d.plan.num_tasks(optimize_graph=True) == num_created_arrays + 4
assert d.plan.total_nbytes(optimize_graph=True) == d.nbytes
assert d.plan.total_nbytes_written(optimize_graph=True) == d.nbytes

task_counter = TaskCounter()
result = d.compute(callbacks=[task_counter])
Expand Down

0 comments on commit 594e137

Please sign in to comment.