Skip to content

Commit

Permalink
Fix broadcast_to when target shape has size 0
Browse files Browse the repository at this point in the history
  • Loading branch information
tomwhite committed Dec 5, 2024
1 parent 9f543ad commit bee866c
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 5 deletions.
10 changes: 6 additions & 4 deletions cubed/array_api/manipulation_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,14 +47,16 @@ def broadcast_to(x, /, shape, *, chunks=None):
):
raise ValueError(f"cannot broadcast shape {x.shape} to shape {shape}")

# TODO: fix case where shape has a dimension of size zero

if chunks is None:
# New dimensions and broadcast dimensions have chunk size 1
# This behaviour differs from dask where it is the full dimension size
xchunks = normalize_chunks(x.chunks, x.shape, dtype=x.dtype)
chunks = tuple((1,) * s for s in shape[:ndim_new]) + tuple(
bd if old > 1 else ((1,) * new if new > 0 else (0,))

def chunklen(shapelen):
return (1,) * shapelen if shapelen > 0 else (0,)

chunks = tuple(chunklen(s) for s in shape[:ndim_new]) + tuple(
bd if old > 1 else chunklen(new)
for bd, old, new in zip(xchunks, x.shape, shape[ndim_new:])
)
else:
Expand Down
3 changes: 2 additions & 1 deletion cubed/tests/test_array_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -474,7 +474,8 @@ def test_broadcast_arrays(executor):
@pytest.mark.parametrize(
"shape, chunks, new_shape, new_chunks, new_chunks_expected",
[
# ((5, 1, 6), (3, 1, 3), (5, 0, 6), None, ((3, 2), (0,), (3, 3))), # fails
((), (), (0,), None, ((0,),)),
((5, 1, 6), (3, 1, 3), (5, 0, 6), None, ((3, 2), (0,), (3, 3))),
((5, 1, 6), (3, 1, 3), (5, 4, 6), None, ((3, 2), (1, 1, 1, 1), (3, 3))),
((5, 1, 6), (3, 1, 3), (2, 5, 1, 6), None, ((1, 1), (3, 2), (1,), (3, 3))),
((5, 1, 6), (3, 1, 3), (5, 3, 6), (3, 3, 3), ((3, 2), (3,), (3, 3))),
Expand Down

0 comments on commit bee866c

Please sign in to comment.