diff --git a/cubed/array_api/creation_functions.py b/cubed/array_api/creation_functions.py index fdf1dfbd..717ec725 100644 --- a/cubed/array_api/creation_functions.py +++ b/cubed/array_api/creation_functions.py @@ -1,8 +1,6 @@ import math from typing import TYPE_CHECKING, Iterable, List -from zarr.util import normalize_shape - from cubed.backend_array_api import namespace as nxp from cubed.core import Plan, gensym from cubed.core.ops import map_blocks @@ -12,7 +10,7 @@ virtual_in_memory, virtual_offsets, ) -from cubed.utils import to_chunksize +from cubed.utils import normalize_shape, to_chunksize from cubed.vendor.dask.array.core import normalize_chunks if TYPE_CHECKING: diff --git a/cubed/random.py b/cubed/random.py index 9852a229..6c60a6c9 100644 --- a/cubed/random.py +++ b/cubed/random.py @@ -1,12 +1,11 @@ import random as pyrandom from numpy.random import Generator, Philox -from zarr.util import normalize_shape from cubed.backend_array_api import namespace as nxp from cubed.backend_array_api import numpy_array_to_backend_array from cubed.core.ops import map_blocks -from cubed.utils import block_id_to_offset +from cubed.utils import block_id_to_offset, normalize_shape from cubed.vendor.dask.array.core import normalize_chunks diff --git a/cubed/tests/test_utils.py b/cubed/tests/test_utils.py index f6ea3a14..81a1358a 100644 --- a/cubed/tests/test_utils.py +++ b/cubed/tests/test_utils.py @@ -15,6 +15,7 @@ join_path, map_nested, memory_repr, + normalize_shape, offset_to_block_id, peak_measured_mem, split_into, @@ -170,3 +171,13 @@ def test_broadcast_trick(): a = nxp.ones((), dtype=nxp.int8) b = broadcast_trick(nxp.ones)((), dtype=nxp.int8) assert_array_equal(a, b) + + +def test_normalize_shape(): + assert normalize_shape(2) == (2,) + assert normalize_shape((2,)) == (2,) + assert normalize_shape((2, 0)) == (2, 0) + assert normalize_shape((2, 3)) == (2, 3) + + with pytest.raises(TypeError): + normalize_shape(None) diff --git a/cubed/utils.py b/cubed/utils.py index 7b962e8f..fb01220e 100644 --- a/cubed/utils.py +++ b/cubed/utils.py @@ -1,5 +1,6 @@ import collections import itertools +import numbers import platform import sys import sysconfig @@ -12,7 +13,7 @@ from operator import add from pathlib import Path from posixpath import join -from typing import Dict, Tuple, Union +from typing import Dict, Tuple, Union, cast from urllib.parse import quote, unquote, urlsplit, urlunsplit import numpy as np @@ -385,3 +386,17 @@ def _concatenate2(arrays, axes=None): return ret else: return concatenate(arrays, axis=axes[0]) + + +def normalize_shape(shape: Union[int, Tuple[int, ...], None]) -> Tuple[int, ...]: + """Normalize a `shape` argument to a tuple of ints.""" + + if shape is None: + raise TypeError("shape is None") + + if isinstance(shape, numbers.Integral): + shape = (int(shape),) + + shape = cast(Tuple[int, ...], shape) + shape = tuple(int(s) for s in shape) + return shape