From a334cbecb12abd55801513398c1c2f68db40e784 Mon Sep 17 00:00:00 2001 From: Tom White Date: Wed, 17 Jan 2024 13:11:09 +0000 Subject: [PATCH] Refactor pipeline to reduce coupling between runtime and primitive layers Introduce PrimitiveOperation that encapsulates information for blockwise and rechunk that the runtime layer does not need --- cubed/core/ops.py | 44 ++++----- cubed/core/optimization.py | 48 +++++---- cubed/core/plan.py | 54 ++++++----- cubed/extensions/history.py | 8 +- cubed/extensions/tqdm.py | 2 +- cubed/primitive/blockwise.py | 123 ++++++++++++++---------- cubed/primitive/rechunk.py | 35 ++++--- cubed/primitive/types.py | 14 +++ cubed/runtime/types.py | 8 +- cubed/tests/primitive/test_blockwise.py | 46 ++++----- cubed/tests/primitive/test_rechunk.py | 26 ++--- 11 files changed, 226 insertions(+), 182 deletions(-) diff --git a/cubed/core/ops.py b/cubed/core/ops.py index b7c1c6109..2bffc2d35 100644 --- a/cubed/core/ops.py +++ b/cubed/core/ops.py @@ -266,7 +266,7 @@ def blockwise( spec = check_array_specs(arrays) if target_store is None: target_store = new_temp_path(name=name, spec=spec) - pipeline = primitive_blockwise( + op = primitive_blockwise( func, out_ind, *zargs, @@ -286,14 +286,14 @@ def blockwise( plan = Plan._new( name, "blockwise", - pipeline.target_array, - pipeline, + op.target_array, + op, False, *source_arrays, ) from cubed.array_api import Array - return Array(name, pipeline.target_array, spec, plan) + return Array(name, op.target_array, spec, plan) def general_blockwise( @@ -322,7 +322,7 @@ def general_blockwise( spec = check_array_specs(arrays) if target_store is None: target_store = new_temp_path(name=name, spec=spec) - pipeline = primitive_general_blockwise( + op = primitive_general_blockwise( func, block_function, *zargs, @@ -340,14 +340,14 @@ def general_blockwise( plan = Plan._new( name, "blockwise", - pipeline.target_array, - pipeline, + op.target_array, + op, False, *source_arrays, ) from cubed.array_api import Array - return Array(name, pipeline.target_array, spec, plan) + return Array(name, op.target_array, spec, plan) def elemwise(func, *args: "Array", dtype=None) -> "Array": @@ -706,7 +706,7 @@ def rechunk(x, chunks, target_store=None): target_store = new_temp_path(name=name, spec=spec) name_int = f"{name}-int" temp_store = new_temp_path(name=name_int, spec=spec) - pipelines = primitive_rechunk( + ops = primitive_rechunk( x.zarray_maybe_lazy, target_chunks=target_chunks, allowed_mem=spec.allowed_mem, @@ -717,40 +717,40 @@ def rechunk(x, chunks, target_store=None): from cubed.array_api import Array - if len(pipelines) == 1: - pipeline = pipelines[0] + if len(ops) == 1: + op = ops[0] plan = Plan._new( name, "rechunk", - pipeline.target_array, - pipeline, + op.target_array, + op, False, x, ) - return Array(name, pipeline.target_array, spec, plan) + return Array(name, op.target_array, spec, plan) else: - pipeline1 = pipelines[0] + op1 = ops[0] plan1 = Plan._new( name_int, "rechunk", - pipeline1.target_array, - pipeline1, + op1.target_array, + op1, False, x, ) - x_int = Array(name_int, pipeline1.target_array, spec, plan1) + x_int = Array(name_int, op1.target_array, spec, plan1) - pipeline2 = pipelines[1] + op2 = ops[1] plan2 = Plan._new( name, "rechunk", - pipeline2.target_array, - pipeline2, + op2.target_array, + op2, False, x_int, ) - return Array(name, pipeline2.target_array, spec, plan2) + return Array(name, op2.target_array, spec, plan2) def merge_chunks(x, chunks): diff --git a/cubed/core/optimization.py b/cubed/core/optimization.py index d89320054..52acac4cd 100644 --- a/cubed/core/optimization.py +++ b/cubed/core/optimization.py @@ -1,8 +1,8 @@ import networkx as nx from cubed.primitive.blockwise import ( - can_fuse_multiple_pipelines, - can_fuse_pipelines, + can_fuse_multiple_primitive_ops, + can_fuse_primitive_ops, fuse, fuse_multiple, ) @@ -23,8 +23,8 @@ def can_fuse(n): op2 = n - # if node (op2) does not have a pipeline then it can't be fused - if "pipeline" not in nodes[op2]: + # if node (op2) does not have a primitive op then it can't be fused + if "primitive_op" not in nodes[op2]: return False # if node (op2) does not have exactly one input then don't fuse @@ -42,10 +42,12 @@ def can_fuse(n): if dag.out_degree(op1) != 1: return False - # op1 and op2 must have pipelines that can be fused - if "pipeline" not in nodes[op1]: + # op1 and op2 must have primitive ops that can be fused + if "primitive_op" not in nodes[op1]: return False - return can_fuse_pipelines(nodes[op1]["pipeline"], nodes[op2]["pipeline"]) + return can_fuse_primitive_ops( + nodes[op1]["primitive_op"], nodes[op2]["primitive_op"] + ) for n in list(dag.nodes()): if can_fuse(n): @@ -54,8 +56,9 @@ def can_fuse(n): op1 = next(dag.predecessors(op2_input)) op1_inputs = list(dag.predecessors(op1)) - pipeline = fuse(nodes[op1]["pipeline"], nodes[op2]["pipeline"]) - nodes[op2]["pipeline"] = pipeline + primitive_op = fuse(nodes[op1]["primitive_op"], nodes[op2]["primitive_op"]) + nodes[op2]["primitive_op"] = primitive_op + nodes[op2]["pipeline"] = primitive_op.pipeline for n in op1_inputs: dag.add_edge(n, op2) @@ -89,7 +92,7 @@ def predecessor_ops(dag, name): def is_fusable(node_dict): "Return True if a node can be fused." - return "pipeline" in node_dict + return "primitive_op" in node_dict def can_fuse_predecessors(dag, name, *, max_total_nargs=4): @@ -113,12 +116,14 @@ def can_fuse_predecessors(dag, name, *, max_total_nargs=4): if total_nargs > max_total_nargs: return False - predecessor_pipelines = [ - nodes[pre]["pipeline"] + predecessor_primitive_ops = [ + nodes[pre]["primitive_op"] for pre in predecessor_ops(dag, name) if is_fusable(nodes[pre]) ] - return can_fuse_multiple_pipelines(nodes[name]["pipeline"], *predecessor_pipelines) + return can_fuse_multiple_primitive_ops( + nodes[name]["primitive_op"], *predecessor_primitive_ops + ) def fuse_predecessors(dag, name): @@ -130,11 +135,11 @@ def fuse_predecessors(dag, name): nodes = dict(dag.nodes(data=True)) - pipeline = nodes[name]["pipeline"] + primitive_op = nodes[name]["primitive_op"] - # if a predecessor op has no pipeline then just use None - predecessor_pipelines = [ - nodes[pre]["pipeline"] if is_fusable(nodes[pre]) else None + # if a predecessor op has no primitive op then just use None + predecessor_primitive_ops = [ + nodes[pre]["primitive_op"] if is_fusable(nodes[pre]) else None for pre in predecessor_ops(dag, name) ] @@ -144,16 +149,17 @@ def fuse_predecessors(dag, name): for pre in predecessor_ops(dag, name) ] - fused_pipeline = fuse_multiple( - pipeline, - *predecessor_pipelines, + fused_primitive_op = fuse_multiple( + primitive_op, + *predecessor_primitive_ops, predecessor_funcs_nargs=predecessor_funcs_nargs, ) fused_dag = dag.copy() fused_nodes = dict(fused_dag.nodes(data=True)) - fused_nodes[name]["pipeline"] = fused_pipeline + fused_nodes[name]["primitive_op"] = fused_primitive_op + fused_nodes[name]["pipeline"] = fused_primitive_op.pipeline # re-wire dag to remove predecessor nodes that have been fused diff --git a/cubed/core/plan.py b/cubed/core/plan.py index d97b5fac3..6043371fc 100644 --- a/cubed/core/plan.py +++ b/cubed/core/plan.py @@ -9,6 +9,7 @@ import zarr from cubed.core.optimization import simple_optimize_dag +from cubed.primitive.types import PrimitiveOperation from cubed.runtime.pipeline import visit_nodes from cubed.runtime.types import CubedPipeline from cubed.storage.zarr import LazyZarrArray @@ -57,7 +58,7 @@ def _new( name, op_name, target, - pipeline=None, + primitive_op=None, hidden=False, *source_arrays, ): @@ -73,7 +74,7 @@ def _new( op_name_unique = gensym() - if pipeline is None: + if primitive_op is None: # op dag.add_node( op_name_unique, @@ -101,7 +102,8 @@ def _new( type="op", stack_summaries=stack_summaries, hidden=hidden, - pipeline=pipeline, + primitive_op=primitive_op, + pipeline=primitive_op.pipeline, ) # array (when multiple outputs are supported there could be more than one) dag.add_node( @@ -137,8 +139,8 @@ def _create_lazy_zarr_arrays(self, dag): lazy_zarr_arrays = [] reserved_mem_values = [] for n, d in dag.nodes(data=True): - if "pipeline" in d and d["pipeline"].reserved_mem is not None: - reserved_mem_values.append(d["pipeline"].reserved_mem) + if "primitive_op" in d and d["primitive_op"].reserved_mem is not None: + reserved_mem_values.append(d["primitive_op"].reserved_mem) all_pipeline_nodes.append(n) if "target" in d and isinstance(d["target"], LazyZarrArray): lazy_zarr_arrays.append(d["target"]) @@ -149,15 +151,14 @@ def _create_lazy_zarr_arrays(self, dag): # add new node and edges name = "create-arrays" op_name = name - pipeline = create_zarr_arrays(lazy_zarr_arrays, reserved_mem) + primitive_op = create_zarr_arrays(lazy_zarr_arrays, reserved_mem) dag.add_node( name, name=name, op_name=op_name, type="op", - pipeline=pipeline, - projected_mem=pipeline.projected_mem, - num_tasks=pipeline.num_tasks, + primitive_op=primitive_op, + pipeline=primitive_op.pipeline, ) dag.add_node( "arrays", @@ -212,8 +213,7 @@ def num_tasks(self, optimize_graph=True, optimize_function=None, resume=None): dag = self._finalize_dag(optimize_graph, optimize_function) tasks = 0 for _, node in visit_nodes(dag, resume=resume): - pipeline = node["pipeline"] - tasks += pipeline.num_tasks + tasks += node["primitive_op"].num_tasks return tasks def num_arrays(self, optimize_graph: bool = True, optimize_function=None) -> int: @@ -227,7 +227,7 @@ def max_projected_mem( """Return the maximum projected memory across all tasks to execute this plan.""" dag = self._finalize_dag(optimize_graph, optimize_function) projected_mem_values = [ - node["pipeline"].projected_mem + node["primitive_op"].projected_mem for _, node in visit_nodes(dag, resume=resume) ] return max(projected_mem_values) if len(projected_mem_values) > 0 else 0 @@ -314,16 +314,18 @@ def visualize( op_name_summary = "" tooltip += f"op: {op_name}" - if "pipeline" in d: - pipeline = d["pipeline"] + if "primitive_op" in d: + primitive_op = d["primitive_op"] tooltip += ( - f"\nprojected memory: {memory_repr(pipeline.projected_mem)}" + f"\nprojected memory: {memory_repr(primitive_op.projected_mem)}" ) - tooltip += f"\ntasks: {pipeline.num_tasks}" - if pipeline.write_chunks is not None: - tooltip += f"\nwrite chunks: {pipeline.write_chunks}" + tooltip += f"\ntasks: {primitive_op.num_tasks}" + if primitive_op.write_chunks is not None: + tooltip += f"\nwrite chunks: {primitive_op.write_chunks}" + del d["primitive_op"] - # remove pipeline attribute since it is a long string that causes graphviz to fail + # remove pipeline attribute since it is a long string that causes graphviz to fail + if "pipeline" in d: del d["pipeline"] if "stack_summaries" in d and d["stack_summaries"] is not None: @@ -433,14 +435,16 @@ def create_zarr_arrays(lazy_zarr_arrays, reserved_mem): ) num_tasks = len(lazy_zarr_arrays) - return CubedPipeline( + pipeline = CubedPipeline( create_zarr_array, "create_zarr_array", lazy_zarr_arrays, None, - None, - projected_mem, - reserved_mem, - num_tasks, - None, + ) + return PrimitiveOperation( + pipeline=pipeline, + target_array=None, + projected_mem=projected_mem, + reserved_mem=reserved_mem, + num_tasks=num_tasks, ) diff --git a/cubed/extensions/history.py b/cubed/extensions/history.py index a5f854591..5c1d2e7e5 100644 --- a/cubed/extensions/history.py +++ b/cubed/extensions/history.py @@ -12,14 +12,14 @@ class HistoryCallback(Callback): def on_compute_start(self, dag, resume): plan = [] for name, node in visit_nodes(dag, resume): - pipeline = node["pipeline"] + primitive_op = node["primitive_op"] plan.append( dict( array_name=name, op_name=node["op_name"], - projected_mem=pipeline.projected_mem, - reserved_mem=pipeline.reserved_mem, - num_tasks=pipeline.num_tasks, + projected_mem=primitive_op.projected_mem, + reserved_mem=primitive_op.reserved_mem, + num_tasks=primitive_op.num_tasks, ) ) diff --git a/cubed/extensions/tqdm.py b/cubed/extensions/tqdm.py index 94b4c94d4..7ad071599 100644 --- a/cubed/extensions/tqdm.py +++ b/cubed/extensions/tqdm.py @@ -20,7 +20,7 @@ def on_compute_start(self, dag, resume): self.pbars = {} i = 0 for name, node in visit_nodes(dag, resume): - num_tasks = node["pipeline"].num_tasks + num_tasks = node["primitive_op"].num_tasks self.pbars[name] = tqdm( *self.args, desc=name, total=num_tasks, position=i, **self.kwargs ) diff --git a/cubed/primitive/blockwise.py b/cubed/primitive/blockwise.py index 1b2466d69..0e895756c 100644 --- a/cubed/primitive/blockwise.py +++ b/cubed/primitive/blockwise.py @@ -20,7 +20,7 @@ from cubed.vendor.dask.blockwise import _get_coord_mapping, _make_dims, lol_product from cubed.vendor.dask.core import flatten -from .types import CubedArrayProxy +from .types import CubedArrayProxy, PrimitiveOperation sym_counter = 0 @@ -252,7 +252,7 @@ def general_blockwise( Returns ------- - CubedPipeline to run the operation + PrimitiveOperation to run the operation """ array_names = in_names or [f"in_{i}" for i in range(len(arrays))] array_map = {name: array for name, array in zip(array_names, arrays)} @@ -298,51 +298,62 @@ def general_blockwise( output_blocks = map(list, itertools.product(*[range(len(c)) for c in chunks])) num_tasks = math.prod(len(c) for c in chunks) - return CubedPipeline( + pipeline = CubedPipeline( apply_blockwise, gensym("apply_blockwise"), output_blocks, spec, - target_array, - projected_mem, - reserved_mem, - num_tasks, - None, + ) + return PrimitiveOperation( + pipeline=pipeline, + target_array=target_array, + projected_mem=projected_mem, + reserved_mem=reserved_mem, + num_tasks=num_tasks, ) -# Code for fusing pipelines +# Code for fusing blockwise operations -def is_fuse_candidate(pipeline: CubedPipeline) -> bool: +def is_fuse_candidate(primitive_op: PrimitiveOperation) -> bool: """ - Return True if a pipeline is a candidate for blockwise fusion. + Return True if a primitive operation is a candidate for blockwise fusion. """ - return pipeline.function == apply_blockwise + return primitive_op.pipeline.function == apply_blockwise -def can_fuse_pipelines(pipeline1: CubedPipeline, pipeline2: CubedPipeline) -> bool: - if is_fuse_candidate(pipeline1) and is_fuse_candidate(pipeline2): - return pipeline1.num_tasks == pipeline2.num_tasks +def can_fuse_primitive_ops( + primitive_op1: PrimitiveOperation, primitive_op2: PrimitiveOperation +) -> bool: + if is_fuse_candidate(primitive_op1) and is_fuse_candidate(primitive_op2): + return primitive_op1.num_tasks == primitive_op2.num_tasks return False -def can_fuse_multiple_pipelines( - pipeline: CubedPipeline, *predecessor_pipelines: CubedPipeline +def can_fuse_multiple_primitive_ops( + primitive_op: PrimitiveOperation, *predecessor_primitive_ops: PrimitiveOperation ) -> bool: - if is_fuse_candidate(pipeline) and all( - is_fuse_candidate(p) for p in predecessor_pipelines + if is_fuse_candidate(primitive_op) and all( + is_fuse_candidate(p) for p in predecessor_primitive_ops ): - return all(pipeline.num_tasks == p.num_tasks for p in predecessor_pipelines) + return all( + primitive_op.num_tasks == p.num_tasks for p in predecessor_primitive_ops + ) return False -def fuse(pipeline1: CubedPipeline, pipeline2: CubedPipeline) -> CubedPipeline: +def fuse( + primitive_op1: PrimitiveOperation, primitive_op2: PrimitiveOperation +) -> PrimitiveOperation: """ - Fuse two blockwise pipelines into a single pipeline, avoiding writing to (or reading from) the target of the first pipeline. + Fuse two blockwise operations into a single operation, avoiding writing to (or reading from) the target of the first operation. """ - assert pipeline1.num_tasks == pipeline2.num_tasks + assert primitive_op1.num_tasks == primitive_op2.num_tasks + + pipeline1 = primitive_op1.pipeline + pipeline2 = primitive_op2.pipeline mappable = pipeline2.mappable @@ -358,39 +369,47 @@ def fused_func(*args): write_proxy = pipeline2.config.write spec = BlockwiseSpec(fused_blockwise_func, fused_func, read_proxies, write_proxy) - target_array = pipeline2.target_array - projected_mem = max(pipeline1.projected_mem, pipeline2.projected_mem) - reserved_mem = max(pipeline1.reserved_mem, pipeline2.reserved_mem) - num_tasks = pipeline2.num_tasks + target_array = primitive_op2.target_array + projected_mem = max(primitive_op1.projected_mem, primitive_op2.projected_mem) + reserved_mem = max(primitive_op1.reserved_mem, primitive_op2.reserved_mem) + num_tasks = primitive_op2.num_tasks - return CubedPipeline( + pipeline = CubedPipeline( apply_blockwise, gensym("fused_apply_blockwise"), mappable, spec, - target_array, - projected_mem, - reserved_mem, - num_tasks, - None, + ) + return PrimitiveOperation( + pipeline=pipeline, + target_array=target_array, + projected_mem=projected_mem, + reserved_mem=reserved_mem, + num_tasks=num_tasks, ) def fuse_multiple( - pipeline: CubedPipeline, - *predecessor_pipelines: CubedPipeline, + primitive_op: PrimitiveOperation, + *predecessor_primitive_ops: PrimitiveOperation, predecessor_funcs_nargs=None, -) -> CubedPipeline: +) -> PrimitiveOperation: """ - Fuse a blockwise pipeline and its predecessors into a single pipeline, avoiding writing to (or reading from) the targets of the predecessor pipelines. + Fuse a blockwise operation and its predecessors into a single operation, avoiding writing to (or reading from) the targets of the predecessor operations. """ assert all( - pipeline.num_tasks == p.num_tasks - for p in predecessor_pipelines + primitive_op.num_tasks == p.num_tasks + for p in predecessor_primitive_ops if p is not None ) + pipeline = primitive_op.pipeline + predecessor_pipelines = [ + primitive_op.pipeline if primitive_op is not None else None + for primitive_op in predecessor_primitive_ops + ] + mappable = pipeline.mappable def apply_pipeline_block_func(pipeline, arg): @@ -430,27 +449,29 @@ def fused_func(*args): write_proxy = pipeline.config.write spec = BlockwiseSpec(fused_blockwise_func, fused_func, read_proxies, write_proxy) - target_array = pipeline.target_array + target_array = primitive_op.target_array projected_mem = max( - pipeline.projected_mem, - *(p.projected_mem for p in predecessor_pipelines if p is not None), + primitive_op.projected_mem, + *(p.projected_mem for p in predecessor_primitive_ops if p is not None), ) reserved_mem = max( - pipeline.reserved_mem, - *(p.reserved_mem for p in predecessor_pipelines if p is not None), + primitive_op.reserved_mem, + *(p.reserved_mem for p in predecessor_primitive_ops if p is not None), ) - num_tasks = pipeline.num_tasks + num_tasks = primitive_op.num_tasks - return CubedPipeline( + fused_pipeline = CubedPipeline( apply_blockwise, gensym("fused_apply_blockwise"), mappable, spec, - target_array, - projected_mem, - reserved_mem, - num_tasks, - None, + ) + return PrimitiveOperation( + pipeline=fused_pipeline, + target_array=target_array, + projected_mem=projected_mem, + reserved_mem=reserved_mem, + num_tasks=num_tasks, ) diff --git a/cubed/primitive/rechunk.py b/cubed/primitive/rechunk.py index 5c91ca2d6..ba28e2802 100644 --- a/cubed/primitive/rechunk.py +++ b/cubed/primitive/rechunk.py @@ -5,7 +5,7 @@ import numpy as np -from cubed.primitive.types import CubedArrayProxy, CubedCopySpec +from cubed.primitive.types import CubedArrayProxy, CubedCopySpec, PrimitiveOperation from cubed.runtime.types import CubedPipeline from cubed.storage.zarr import T_ZarrArray, lazy_empty from cubed.types import T_RegularChunks, T_Shape, T_Store @@ -27,7 +27,7 @@ def rechunk( reserved_mem: int, target_store: T_Store, temp_store: Optional[T_Store] = None, -) -> List[CubedPipeline]: +) -> List[PrimitiveOperation]: """Change the chunking of an array, without changing its shape or dtype. Parameters @@ -46,7 +46,7 @@ def rechunk( Returns ------- - CubedPipeline to run the operation + PrimitiveOperation to run the operation """ # rechunker doesn't take account of uncompressed and compressed copies of the @@ -71,24 +71,26 @@ def rechunk( copy_spec = CubedCopySpec(read_proxy, write_proxy) num_tasks = total_chunks(write_proxy.array.shape, write_proxy.chunks) return [ - spec_to_pipeline(copy_spec, target, projected_mem, reserved_mem, num_tasks) + spec_to_primitive_op( + copy_spec, target, projected_mem, reserved_mem, num_tasks + ) ] else: # break spec into two if there's an intermediate copy_spec1 = CubedCopySpec(read_proxy, int_proxy) num_tasks = total_chunks(copy_spec1.write.array.shape, copy_spec1.write.chunks) - pipeline1 = spec_to_pipeline( + op1 = spec_to_primitive_op( copy_spec1, intermediate, projected_mem, reserved_mem, num_tasks ) copy_spec2 = CubedCopySpec(int_proxy, write_proxy) num_tasks = total_chunks(copy_spec2.write.array.shape, copy_spec2.write.chunks) - pipeline2 = spec_to_pipeline( + op2 = spec_to_primitive_op( copy_spec2, target, projected_mem, reserved_mem, num_tasks ) - return [pipeline1, pipeline2] + return [op1, op2] # from rechunker, but simpler since it only has to handle Zarr arrays @@ -185,23 +187,26 @@ def copy_read_to_write(chunk_key: Sequence[slice], *, config: CubedCopySpec) -> config.write.open()[chunk_key] = data -def spec_to_pipeline( +def spec_to_primitive_op( spec: CubedCopySpec, target_array: Any, projected_mem: int, reserved_mem: int, num_tasks: int, -) -> CubedPipeline: +) -> PrimitiveOperation: # typing won't work until we start using numpy types shape = spec.read.array.shape # type: ignore - return CubedPipeline( + pipeline = CubedPipeline( copy_read_to_write, gensym("copy_read_to_write"), ChunkKeys(shape, spec.write.chunks), spec, - target_array, - projected_mem, - reserved_mem, - num_tasks, - spec.write.chunks, + ) + return PrimitiveOperation( + pipeline=pipeline, + target_array=target_array, + projected_mem=projected_mem, + reserved_mem=reserved_mem, + num_tasks=num_tasks, + write_chunks=spec.write.chunks, ) diff --git a/cubed/primitive/types.py b/cubed/primitive/types.py index a6f25089c..f4df42c72 100644 --- a/cubed/primitive/types.py +++ b/cubed/primitive/types.py @@ -1,11 +1,25 @@ from dataclasses import dataclass +from typing import Any, Optional import zarr +from cubed.runtime.types import CubedPipeline from cubed.storage.zarr import T_ZarrArray, open_if_lazy_zarr_array from cubed.types import T_RegularChunks +@dataclass(frozen=True) +class PrimitiveOperation: + """Encapsulates metadata about a ``blockwise`` or ``rechunk`` primitive operation.""" + + pipeline: CubedPipeline + target_array: Any + projected_mem: int + reserved_mem: int + num_tasks: int + write_chunks: Optional[T_RegularChunks] = None + + class CubedArrayProxy: """Generalisation of rechunker ``ArrayProxy`` with support for ``LazyZarrArray``.""" diff --git a/cubed/runtime/types.py b/cubed/runtime/types.py index 4b3188b9e..cd6da9a1b 100644 --- a/cubed/runtime/types.py +++ b/cubed/runtime/types.py @@ -1,9 +1,8 @@ from dataclasses import dataclass -from typing import Any, Iterable, Optional +from typing import Iterable, Optional from networkx import MultiDiGraph -from cubed.types import T_RegularChunks from cubed.vendor.rechunker.types import Config, StageFunction @@ -23,11 +22,6 @@ class CubedPipeline: name: str mappable: Iterable config: Config - target_array: Any - projected_mem: int - reserved_mem: int - num_tasks: int - write_chunks: Optional[T_RegularChunks] class Callback: diff --git a/cubed/tests/primitive/test_blockwise.py b/cubed/tests/primitive/test_blockwise.py index 0427e5ca0..f9554fcbe 100644 --- a/cubed/tests/primitive/test_blockwise.py +++ b/cubed/tests/primitive/test_blockwise.py @@ -30,7 +30,7 @@ def test_blockwise(tmp_path, executor, reserved_mem): allowed_mem = 1000 target_store = tmp_path / "target.zarr" - pipeline = blockwise( + op = blockwise( nxp.linalg.outer, "ij", source1, @@ -45,12 +45,12 @@ def test_blockwise(tmp_path, executor, reserved_mem): chunks=(2, 2), ) - assert pipeline.target_array.shape == (3, 3) - assert pipeline.target_array.dtype == int - assert pipeline.target_array.chunks == (2, 2) + assert op.target_array.shape == (3, 3) + assert op.target_array.dtype == int + assert op.target_array.chunks == (2, 2) itemsize = np.dtype(int).itemsize - assert pipeline.projected_mem == ( + assert op.projected_mem == ( reserved_mem # projected includes reserved + (itemsize * 2) # source1 compressed chunk + (itemsize * 2) # source1 uncompressed chunk @@ -60,11 +60,11 @@ def test_blockwise(tmp_path, executor, reserved_mem): + (itemsize * 2 * 2) # output uncompressed chunk ) - assert pipeline.num_tasks == 4 + assert op.num_tasks == 4 - pipeline.target_array.create() # create lazy zarr array + op.target_array.create() # create lazy zarr array - execute_pipeline(pipeline, executor=executor) + execute_pipeline(op.pipeline, executor=executor) res = zarr.open_array(target_store) assert_array_equal(res[:], np.outer([0, 1, 2], [10, 50, 100])) @@ -103,7 +103,7 @@ def test_blockwise_with_args(tmp_path, executor): allowed_mem = 1000 target_store = tmp_path / "target.zarr" - pipeline = _permute_dims( + op = _permute_dims( source, axes=(1, 0), allowed_mem=allowed_mem, @@ -111,23 +111,23 @@ def test_blockwise_with_args(tmp_path, executor): target_store=target_store, ) - assert pipeline.target_array.shape == (3, 3) - assert pipeline.target_array.dtype == int - assert pipeline.target_array.chunks == (2, 2) + assert op.target_array.shape == (3, 3) + assert op.target_array.dtype == int + assert op.target_array.chunks == (2, 2) itemsize = np.dtype(int).itemsize - assert pipeline.projected_mem == ( + assert op.projected_mem == ( (itemsize * 2 * 2) # source compressed chunk + (itemsize * 2 * 2) # source uncompressed chunk + (itemsize * 2 * 2) # output compressed chunk + (itemsize * 2 * 2) # output uncompressed chunk ) - assert pipeline.num_tasks == 4 + assert op.num_tasks == 4 - pipeline.target_array.create() # create lazy zarr array + op.target_array.create() # create lazy zarr array - execute_pipeline(pipeline, executor=executor) + execute_pipeline(op.pipeline, executor=executor) res = zarr.open_array(target_store) assert_array_equal( @@ -197,7 +197,7 @@ def block_function(out_key): ], ) - pipeline = general_blockwise( + op = general_blockwise( merge_chunks, block_function, source, @@ -210,15 +210,15 @@ def block_function(out_key): in_names=[in_name], ) - assert pipeline.target_array.shape == (20,) - assert pipeline.target_array.dtype == int - assert pipeline.target_array.chunks == (6,) + assert op.target_array.shape == (20,) + assert op.target_array.dtype == int + assert op.target_array.chunks == (6,) - assert pipeline.num_tasks == 4 + assert op.num_tasks == 4 - pipeline.target_array.create() # create lazy zarr array + op.target_array.create() # create lazy zarr array - execute_pipeline(pipeline, executor=executor) + execute_pipeline(op.pipeline, executor=executor) res = zarr.open_array(target_store) assert_array_equal(res[:], np.arange(20)) diff --git a/cubed/tests/primitive/test_rechunk.py b/cubed/tests/primitive/test_rechunk.py index cc6b8a1b9..34d524e10 100644 --- a/cubed/tests/primitive/test_rechunk.py +++ b/cubed/tests/primitive/test_rechunk.py @@ -63,7 +63,7 @@ def test_rechunk( target_store = tmp_path / "target.zarr" temp_store = tmp_path / "temp.zarr" - pipelines = rechunk( + ops = rechunk( source, target_chunks=target_chunks, allowed_mem=allowed_mem, @@ -72,25 +72,25 @@ def test_rechunk( temp_store=temp_store, ) - assert len(pipelines) == len(expected_num_tasks) + assert len(ops) == len(expected_num_tasks) - for i, pipeline in enumerate(pipelines): - assert pipeline.target_array.shape == shape - assert pipeline.target_array.dtype == source.dtype + for i, op in enumerate(ops): + assert op.target_array.shape == shape + assert op.target_array.dtype == source.dtype - assert pipeline.projected_mem == expected_projected_mem + assert op.projected_mem == expected_projected_mem - assert pipeline.num_tasks == expected_num_tasks[i] + assert op.num_tasks == expected_num_tasks[i] - last_pipeline = pipelines[-1] - assert last_pipeline.target_array.chunks == target_chunks + last_op = ops[-1] + assert last_op.target_array.chunks == target_chunks # create lazy zarr arrays - for pipeline in pipelines: - pipeline.target_array.create() + for op in ops: + op.target_array.create() - for pipeline in pipelines: - execute_pipeline(pipeline, executor=executor) + for op in ops: + execute_pipeline(op.pipeline, executor=executor) res = zarr.open_array(target_store) assert_array_equal(res[:], np.ones(shape))