Skip to content

Commit

Permalink
Mark stack operation as not fusable (#415)
Browse files Browse the repository at this point in the history
  • Loading branch information
tomwhite authored Mar 6, 2024
1 parent be962a4 commit f567e02
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 0 deletions.
5 changes: 5 additions & 0 deletions cubed/array_api/manipulation_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,10 @@ def block_function(out_key):
in_name = array_names[out_coords[axis]]
return ((in_name, *(out_coords[:axis] + out_coords[(axis + 1) :])),)

# We have to mark this as fusable=False since the number of input args to
# the _read_stack_chunk function is *not* the same as the number of
# predecessor nodes in the DAG, and the fusion functions in blockwise
# assume they are the same. See https://github.com/cubed-dev/cubed/issues/414
return general_blockwise(
_read_stack_chunk,
block_function,
Expand All @@ -304,6 +308,7 @@ def block_function(out_key):
dtype=dtype,
chunks=chunks,
axis=axis,
fusable=False,
)


Expand Down
10 changes: 10 additions & 0 deletions cubed/tests/test_optimization.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import cubed
import cubed.array_api as xp
import cubed.random
from cubed.backend_array_api import namespace as nxp
from cubed.core.ops import elemwise, merge_chunks_new, partial_reduce
from cubed.core.optimization import (
Expand Down Expand Up @@ -907,3 +908,12 @@ def test_fuse_only_optimize_dag(spec):

result = d.compute(optimize_function=opt_fn)
assert_array_equal(result, -np.ones((2, 2)))


def test_optimize_stack(spec):
# This test fails if stack's general_blockwise call doesn't have fusable=False
a = cubed.random.random((10, 10), chunks=(5, 5), spec=spec)
b = cubed.random.random((10, 10), chunks=(5, 5), spec=spec)
c = xp.stack((a, b), axis=0)
# try to fuse all ops into one
c.compute(optimize_function=fuse_multiple_levels(max_total_num_input_blocks=10))

0 comments on commit f567e02

Please sign in to comment.