Skip to content

Commit

Permalink
Track source array names in PrimitiveOperation to fix a bug with argu…
Browse files Browse the repository at this point in the history
…ment ordering

The bug was due to a mismatch between the order of the input source arrays
in the DAG and the function argument order for the operation. Since the DAG
order is not stable, the operation now records the names of the input source
arrays in the source_array_names variable.
  • Loading branch information
tomwhite committed Mar 8, 2024
1 parent f567e02 commit dba223a
Show file tree
Hide file tree
Showing 7 changed files with 84 additions and 16 deletions.
2 changes: 2 additions & 0 deletions cubed/core/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -744,6 +744,8 @@ def rechunk(x, chunks, target_store=None):
reserved_mem=spec.reserved_mem,
target_store=target_store,
temp_store=temp_store,
source_array_name=name,
int_array_name=name_int,
)

from cubed.array_api import Array
Expand Down
29 changes: 17 additions & 12 deletions cubed/core/optimization.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,17 +81,22 @@ def gensym(name="op"):
return f"{name}-{sym_counter:03}"


def predecessors(dag, name):
"""Return a node's predecessors, with repeats for multiple edges."""
def predecessors_unordered(dag, name):
"""Return a node's predecessors in no particular order, with repeats for multiple edges."""
for pre, _ in dag.in_edges(name):
yield pre


def predecessor_ops(dag, name):
"""Return an op node's op predecessors"""
for input in predecessors(dag, name):
for pre in predecessors(dag, input):
yield pre
"""Return an op node's op predecessors in the same order as the input source arrays for the op.
Note that each input source array is produced by a single predecessor op.
"""
nodes = dict(dag.nodes(data=True))
for input in nodes[name]["primitive_op"].source_array_names:
pre_list = list(predecessors_unordered(dag, input))
assert len(pre_list) == 1 # each array is produced by a single op
yield pre_list[0]


def is_fusable(node_dict):
Expand Down Expand Up @@ -135,7 +140,7 @@ def can_fuse_predecessors(
# the fused function would be more than an allowed maximum, then don't fuse
if len(list(predecessor_ops(dag, name))) > 1:
total_source_arrays = sum(
len(list(predecessors(dag, pre))) if is_fusable(nodes[pre]) else 1
len(list(predecessors_unordered(dag, pre))) if is_fusable(nodes[pre]) else 1
for pre in predecessor_ops(dag, name)
)
if total_source_arrays > max_total_source_arrays:
Expand Down Expand Up @@ -203,8 +208,8 @@ def fuse_predecessors(
# re-wire dag to remove predecessor nodes that have been fused

# 1. update edges to change inputs
for input in predecessors(dag, name):
pre = next(predecessors(dag, input))
for input in predecessors_unordered(dag, name):
pre = next(predecessors_unordered(dag, input))
if not is_fusable(fused_nodes[pre]):
# if a predecessor is not fusable then don't change the edge
continue
Expand All @@ -213,14 +218,14 @@ def fuse_predecessors(
if not is_fusable(fused_nodes[pre]):
# if a predecessor is not fusable then don't change the edge
continue
for input in predecessors(dag, pre):
for input in predecessors_unordered(dag, pre):
fused_dag.add_edge(input, name)

# 2. remove predecessor nodes with no successors
# (ones with successors are needed by other nodes)
for input in predecessors(dag, name):
for input in predecessors_unordered(dag, name):
if fused_dag.out_degree(input) == 0:
for pre in list(predecessors(fused_dag, input)):
for pre in list(predecessors_unordered(fused_dag, input)):
fused_dag.remove_node(pre)
fused_dag.remove_node(input)

Expand Down
1 change: 1 addition & 0 deletions cubed/core/plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -462,6 +462,7 @@ def create_zarr_arrays(lazy_zarr_arrays, allowed_mem, reserved_mem):
)
return PrimitiveOperation(
pipeline=pipeline,
source_array_names=None,
target_array=None,
projected_mem=projected_mem,
allowed_mem=allowed_mem,
Expand Down
10 changes: 10 additions & 0 deletions cubed/primitive/blockwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,6 +331,7 @@ def general_blockwise(
)
return PrimitiveOperation(
pipeline=pipeline,
source_array_names=array_names,
target_array=target_array,
projected_mem=projected_mem,
allowed_mem=allowed_mem,
Expand Down Expand Up @@ -483,6 +484,7 @@ def fused_func(*args):
write_proxy,
)

source_array_names = primitive_op1.source_array_names
target_array = primitive_op2.target_array
projected_mem = max(primitive_op1.projected_mem, primitive_op2.projected_mem)
allowed_mem = primitive_op2.allowed_mem
Expand All @@ -497,6 +499,7 @@ def fused_func(*args):
)
return PrimitiveOperation(
pipeline=pipeline,
source_array_names=source_array_names,
target_array=target_array,
projected_mem=projected_mem,
allowed_mem=allowed_mem,
Expand Down Expand Up @@ -607,6 +610,12 @@ def fused_func(*args):
write_proxy,
)

source_array_names = []
for i, p in enumerate(predecessor_primitive_ops):
if p is None:
source_array_names.append(primitive_op.source_array_names[i])
else:
source_array_names.extend(p.source_array_names)
target_array = primitive_op.target_array
projected_mem = max(
primitive_op.projected_mem,
Expand All @@ -624,6 +633,7 @@ def fused_func(*args):
)
return PrimitiveOperation(
pipeline=fused_pipeline,
source_array_names=source_array_names,
target_array=target_array,
projected_mem=projected_mem,
allowed_mem=allowed_mem,
Expand Down
21 changes: 19 additions & 2 deletions cubed/primitive/rechunk.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ def rechunk(
reserved_mem: int,
target_store: T_Store,
temp_store: Optional[T_Store] = None,
source_array_name: Optional[str] = None,
int_array_name: Optional[str] = None,
) -> List[PrimitiveOperation]:
"""Change the chunking of an array, without changing its shape or dtype.
Expand Down Expand Up @@ -72,7 +74,13 @@ def rechunk(
num_tasks = total_chunks(write_proxy.array.shape, write_proxy.chunks)
return [
spec_to_primitive_op(
copy_spec, target, projected_mem, allowed_mem, reserved_mem, num_tasks
copy_spec,
source_array_name,
target,
projected_mem,
allowed_mem,
reserved_mem,
num_tasks,
)
]

Expand All @@ -82,6 +90,7 @@ def rechunk(
num_tasks = total_chunks(copy_spec1.write.array.shape, copy_spec1.write.chunks)
op1 = spec_to_primitive_op(
copy_spec1,
source_array_name,
intermediate,
projected_mem,
allowed_mem,
Expand All @@ -92,7 +101,13 @@ def rechunk(
copy_spec2 = CubedCopySpec(int_proxy, write_proxy)
num_tasks = total_chunks(copy_spec2.write.array.shape, copy_spec2.write.chunks)
op2 = spec_to_primitive_op(
copy_spec2, target, projected_mem, allowed_mem, reserved_mem, num_tasks
copy_spec2,
int_array_name,
target,
projected_mem,
allowed_mem,
reserved_mem,
num_tasks,
)

return [op1, op2]
Expand Down Expand Up @@ -184,6 +199,7 @@ def copy_read_to_write(chunk_key: Sequence[slice], *, config: CubedCopySpec) ->

def spec_to_primitive_op(
spec: CubedCopySpec,
source_array_name: Optional[str],
target_array: Any,
projected_mem: int,
allowed_mem: int,
Expand All @@ -200,6 +216,7 @@ def spec_to_primitive_op(
)
return PrimitiveOperation(
pipeline=pipeline,
source_array_names=[source_array_name],
target_array=target_array,
projected_mem=projected_mem,
allowed_mem=allowed_mem,
Expand Down
5 changes: 4 additions & 1 deletion cubed/primitive/types.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from dataclasses import dataclass
from typing import Any, Optional
from typing import Any, List, Optional

import zarr

Expand All @@ -15,6 +15,9 @@ class PrimitiveOperation:
pipeline: CubedPipeline
"""The pipeline that runs this operation."""

source_array_names: Optional[List[str]]
"""The names of the arrays which are inputs to this operation."""

target_array: Any
"""The array being computed by this operation."""

Expand Down
32 changes: 31 additions & 1 deletion cubed/tests/test_optimization.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,8 @@ def add_placeholder_op(dag, inputs, outputs):
def structurally_equivalent(dag1, dag2):
# compare structure, and node labels for values but not operators since they are placeholders

# draw_dag(dag1) # uncomment for debugging
# draw_dag(dag1, "dag1") # uncomment for debugging
# draw_dag(dag2, "dag2") # uncomment for debugging

labelled_dag1 = nx.convert_node_labels_to_integers(dag1, label_attribute="label")
labelled_dag2 = nx.convert_node_labels_to_integers(dag2, label_attribute="label")
Expand All @@ -209,6 +210,10 @@ def nm(node_attrs1, node_attrs2):


def draw_dag(dag, name="dag"):
dag = dag.copy()
for _, d in dag.nodes(data=True):
if "name" in d: # pydot already has name
del d["name"]
gv = nx.drawing.nx_pydot.to_pydot(dag)
format = "svg"
full_filename = f"{name}.{format}"
Expand Down Expand Up @@ -429,6 +434,31 @@ def test_fuse_mixed_levels_and_diamond(spec):
assert_array_equal(result, 2 * np.ones((2, 2)))


# derived from a bug found by array_api_tests/test_manipulation_functions.py::test_expand_dims
# a b -> a b
# \ / |\ /|
# c | d |
# /| | | |
# d | | e |
# | | \|/
# e | f
# \|
# f
def test_fuse_mixed_levels_and_diamond_complex(spec):
a = xp.ones((2, 2), chunks=(2, 2), spec=spec)
b = xp.ones((2, 2), chunks=(2, 2), spec=spec)
c = xp.add(a, b)
d = xp.positive(c)
e = d[1:, :] # operation can't be fused
f = xp.add(e, c) # this order exposed a bug in argument ordering

opt_fn = multiple_inputs_optimize_dag

f.visualize(optimize_function=opt_fn)
result = f.compute(optimize_function=opt_fn)
assert_array_equal(result, 4 * np.ones((2, 2)))


# repeated argument
# from https://github.com/cubed-dev/cubed/issues/65
#
Expand Down

0 comments on commit dba223a

Please sign in to comment.