diff --git a/api_status.md b/api_status.md
index 76925976..a71b54c7 100644
--- a/api_status.md
+++ b/api_status.md
@@ -61,7 +61,7 @@ This table shows which parts of the the [Array API](https://data-apis.org/array-
 | | `expand_dims` | :white_check_mark: | | |
 | | `flip` | :white_check_mark: | | |
 | | `permute_dims` | :white_check_mark: | | |
-| | `repeat` | :x: | 2023.12 | |
+| | `repeat` | :white_check_mark: | | |
 | | `reshape` | :white_check_mark: | | Partial implementation |
 | | `roll` | :white_check_mark: | | |
 | | `squeeze` | :white_check_mark: | | |
diff --git a/cubed/__init__.py b/cubed/__init__.py
index ada6c2f2..31c56240 100644
--- a/cubed/__init__.py
+++ b/cubed/__init__.py
@@ -296,6 +296,7 @@
     flip,
     moveaxis,
     permute_dims,
+    repeat,
     reshape,
     roll,
     squeeze,
@@ -311,6 +312,7 @@
     "flip",
     "moveaxis",
     "permute_dims",
+    "repeat",
     "reshape",
     "roll",
     "squeeze",
diff --git a/cubed/array_api/__init__.py b/cubed/array_api/__init__.py
index 66c3edd8..587242f9 100644
--- a/cubed/array_api/__init__.py
+++ b/cubed/array_api/__init__.py
@@ -238,6 +238,7 @@
     flip,
     moveaxis,
     permute_dims,
+    repeat,
     reshape,
     roll,
     squeeze,
@@ -253,6 +254,7 @@
     "flip",
     "moveaxis",
     "permute_dims",
+    "repeat",
     "reshape",
     "roll",
     "squeeze",
@@ -264,7 +266,7 @@
 
 __all__ += ["argmax", "argmin", "where"]
 
-from .statistical_functions import max, mean, min, prod, sum, std, var
+from .statistical_functions import max, mean, min, prod, std, sum, var
 
 __all__ += ["max", "mean", "min", "prod", "std", "sum", "var"]
 
diff --git a/cubed/array_api/manipulation_functions.py b/cubed/array_api/manipulation_functions.py
index 35eb76bc..284cda0e 100644
--- a/cubed/array_api/manipulation_functions.py
+++ b/cubed/array_api/manipulation_functions.py
@@ -352,6 +352,50 @@ def permute_dims(x, /, axes):
     )
 
 
+def repeat(x, repeats, /, *, axis=0):
+    if axis is None:
+        x = flatten(x)
+        axis = 0
+
+    shape = x.shape[:axis] + (x.shape[axis] * repeats,) + x.shape[axis + 1 :]
+    chunks = normalize_chunks(x.chunksize, shape=shape, dtype=x.dtype)
+
+    # This implementation calls nxp.repeat in every output block, which is 'repeats'
+    # times more work than necessary; a primitive op that could write multiple blocks would avoid this.
+
+    def key_function(out_key):
+        out_coords = out_key[1:]
+        in_coords = tuple(
+            bi // repeats if i == axis else bi for i, bi in enumerate(out_coords)
+        )
+        return ((x.name, *in_coords),)
+
+    # extra memory from calling 'nxp.repeat' on a chunk
+    extra_projected_mem = x.chunkmem * repeats
+    return general_blockwise(
+        _repeat,
+        key_function,
+        x,
+        shapes=[shape],
+        dtypes=[x.dtype],
+        chunkss=[chunks],
+        extra_projected_mem=extra_projected_mem,
+        repeats=repeats,
+        axis=axis,
+        chunksize=x.chunksize,
+    )
+
+
+def _repeat(x, repeats, axis=None, chunksize=None, block_id=None):
+    out = nxp.repeat(x, repeats, axis=axis)
+    bi = block_id[axis] % repeats
+    ind = tuple(
+        slice(bi * chunksize[i], (bi + 1) * chunksize[i]) if i == axis else slice(None)
+        for i in range(x.ndim)
+    )
+    return out[ind]
+
+
 def reshape(x, /, shape, *, copy=None):
     # based on dask reshape
 
diff --git a/cubed/tests/test_array_api.py b/cubed/tests/test_array_api.py
index 9c09d667..42a245cb 100644
--- a/cubed/tests/test_array_api.py
+++ b/cubed/tests/test_array_api.py
@@ -557,6 +557,15 @@ def test_permute_dims(spec, executor):
     )
 
 
+def test_repeat(spec):
+    a = xp.asarray([[1, 2, 3], [4, 5, 6], [7, 8, 9]], chunks=(2, 2), spec=spec)
+    b = xp.repeat(a, 3, axis=1)
+    assert_array_equal(
+        b.compute(),
+        np.repeat(np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), 3, axis=1),
+    )
+
+
 def test_reshape(spec, executor):
     a = xp.arange(12, chunks=4, spec=spec)
     b = xp.reshape(a, (3, 4))
diff --git a/cubed/tests/test_mem_utilization.py b/cubed/tests/test_mem_utilization.py
index be02973c..c6e73d6a 100644
--- a/cubed/tests/test_mem_utilization.py
+++ b/cubed/tests/test_mem_utilization.py
@@ -285,6 +285,15 @@ def test_flip_multiple_axes(tmp_path, spec, executor):
     run_operation(tmp_path, executor, "flip_multiple_axes", b)
 
 
+@pytest.mark.slow
+def test_repeat(tmp_path, spec, executor):
+    a = cubed.random.random(
+        (10000, 10000), chunks=(5000, 5000), spec=spec
+    )  # 200MB chunks
+    b = xp.repeat(a, 3, axis=0)
+    run_operation(tmp_path, executor, "repeat", b)
+
+
 @pytest.mark.slow
 def test_reshape(tmp_path, spec, executor):
     a = cubed.random.random(
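Note on the repeat implementation above: each output block along `axis` reads the single input block at index `out_block // repeats` (via key_function), calls nxp.repeat on that whole chunk, and keeps only the chunk-sized slice at offset `out_block % repeats` (via _repeat). The sketch below is illustrative only, not part of the change; it reproduces the same block arithmetic with plain NumPy, and the names chunk, repeats, axis and chunksize are hypothetical stand-ins for the values Cubed supplies at runtime.

    import numpy as np

    # One 2x2 input chunk of a larger array, repeated 3 times along axis 1.
    chunk = np.array([[1, 2],
                      [4, 5]])
    repeats, axis, chunksize = 3, 1, (2, 2)

    # Every output block repeats the whole input chunk, then slices out its share,
    # so the repeat work is done 'repeats' times per input chunk.
    for out_block in range(repeats):
        out = np.repeat(chunk, repeats, axis=axis)  # what _repeat computes per block
        bi = out_block % repeats                    # position within the repeated chunk
        ind = tuple(
            slice(bi * chunksize[i], (bi + 1) * chunksize[i]) if i == axis else slice(None)
            for i in range(chunk.ndim)
        )
        print(out_block, out[ind])
        # block 0 -> [[1 1], [4 4]], block 1 -> [[1 2], [4 5]], block 2 -> [[2 2], [5 5]]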