Skip to content

Commit

Permalink
Implement reshape_chunks using general_blockwise
Browse files Browse the repository at this point in the history
  • Loading branch information
tomwhite committed Jan 20, 2024
1 parent ea760d1 commit 2e68a7c
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 18 deletions.
36 changes: 18 additions & 18 deletions cubed/array_api/manipulation_functions.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from bisect import bisect
from itertools import product
from operator import add, mul

import numpy as np
Expand All @@ -8,11 +7,10 @@

from cubed.array_api.creation_functions import empty
from cubed.backend_array_api import namespace as nxp
from cubed.backend_array_api import numpy_array_to_backend_array
from cubed.core import squeeze # noqa: F401
from cubed.core import blockwise, rechunk, unify_chunks
from cubed.core.ops import elemwise, general_blockwise, map_blocks, map_direct
from cubed.utils import get_item, to_chunksize
from cubed.utils import block_id_to_offset, get_item, offset_to_block_id, to_chunksize
from cubed.vendor.dask.array.core import broadcast_chunks, normalize_chunks
from cubed.vendor.dask.array.reshape import reshape_rechunk
from cubed.vendor.dask.array.utils import validate_axis
Expand Down Expand Up @@ -245,33 +243,35 @@ def reshape_chunks(x, shape, chunks):
if reduce(mul, shape, 1) != x.size:
raise ValueError("total size of new array must be unchanged")

inchunks = normalize_chunks(x.chunks, shape=x.shape, dtype=x.dtype)
# TODO: check number of chunks is unchanged
# inchunks = normalize_chunks(x.chunks, shape=x.shape, dtype=x.dtype)
outchunks = normalize_chunks(chunks, shape=shape, dtype=x.dtype)

# TODO: check number of chunks is unchanged
# use an empty template (handles smaller end chunks)
template = empty(shape, dtype=x.dtype, chunks=chunks, spec=x.spec)

# memory allocated by reading one chunk from input array
extra_projected_mem = x.chunkmem
def block_function(out_key):
out_coords = out_key[1:]
offset = block_id_to_offset(out_coords, template.numblocks)
in_coords = offset_to_block_id(offset, x.numblocks)
return (
(x.name, *in_coords),
(template.name, *out_coords),
)

return map_direct(
return general_blockwise(
_reshape_chunk,
block_function,
x,
template,
shape=shape,
dtype=x.dtype,
chunks=outchunks,
extra_projected_mem=extra_projected_mem,
inchunks=inchunks,
outchunks=outchunks,
)


def _reshape_chunk(e, x, inchunks=None, outchunks=None, block_id=None):
in_keys = list(product(*[range(len(c)) for c in inchunks]))
out_keys = list(product(*[range(len(c)) for c in outchunks]))
idx = in_keys[out_keys.index(block_id)]
out = x.zarray[get_item(x.chunks, idx)]
out = numpy_array_to_backend_array(out)
return nxp.reshape(out, e.shape)
def _reshape_chunk(x, template):
return nxp.reshape(x, template.shape)


def stack(arrays, /, *, axis=0):
Expand Down
13 changes: 13 additions & 0 deletions cubed/tests/test_array_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -480,6 +480,19 @@ def test_reshape_chunks(spec, executor):
)


def test_reshape_chunks_with_smaller_end_chunk(spec, executor):
a = xp.arange(10, chunks=4, spec=spec)
b = reshape_chunks(a, (2, 5), (2, 2))

assert b.shape == (2, 5)
assert b.chunks == ((2,), (2, 2, 1))

assert_array_equal(
b.compute(executor=executor),
np.array([[0, 1, 4, 5, 8], [2, 3, 6, 7, 9]]),
)


def test_squeeze_1d(spec, executor):
a = xp.asarray([[1, 2, 3]], chunks=(1, 2), spec=spec)
b = xp.squeeze(a, 0)
Expand Down

0 comments on commit 2e68a7c

Please sign in to comment.