diff --git a/cubed/array_api/manipulation_functions.py b/cubed/array_api/manipulation_functions.py index 4948d203..93ea2341 100644 --- a/cubed/array_api/manipulation_functions.py +++ b/cubed/array_api/manipulation_functions.py @@ -47,14 +47,16 @@ def broadcast_to(x, /, shape, *, chunks=None): ): raise ValueError(f"cannot broadcast shape {x.shape} to shape {shape}") - # TODO: fix case where shape has a dimension of size zero - if chunks is None: # New dimensions and broadcast dimensions have chunk size 1 # This behaviour differs from dask where it is the full dimension size xchunks = normalize_chunks(x.chunks, x.shape, dtype=x.dtype) - chunks = tuple((1,) * s for s in shape[:ndim_new]) + tuple( - bd if old > 1 else ((1,) * new if new > 0 else (0,)) + + def chunklen(shapelen): + return (1,) * shapelen if shapelen > 0 else (0,) + + chunks = tuple(chunklen(s) for s in shape[:ndim_new]) + tuple( + bd if old > 1 else chunklen(new) for bd, old, new in zip(xchunks, x.shape, shape[ndim_new:]) ) else: diff --git a/cubed/tests/test_array_api.py b/cubed/tests/test_array_api.py index 4ea6653c..327e6726 100644 --- a/cubed/tests/test_array_api.py +++ b/cubed/tests/test_array_api.py @@ -474,7 +474,8 @@ def test_broadcast_arrays(executor): @pytest.mark.parametrize( "shape, chunks, new_shape, new_chunks, new_chunks_expected", [ - # ((5, 1, 6), (3, 1, 3), (5, 0, 6), None, ((3, 2), (0,), (3, 3))), # fails + ((), (), (0,), None, ((0,),)), + ((5, 1, 6), (3, 1, 3), (5, 0, 6), None, ((3, 2), (0,), (3, 3))), ((5, 1, 6), (3, 1, 3), (5, 4, 6), None, ((3, 2), (1, 1, 1, 1), (3, 3))), ((5, 1, 6), (3, 1, 3), (2, 5, 1, 6), None, ((1, 1), (3, 2), (1,), (3, 3))), ((5, 1, 6), (3, 1, 3), (5, 3, 6), (3, 3, 3), ((3, 2), (3,), (3, 3))),