From fb12b09683cc01007bc4df3fc0cb5a9718d297d5 Mon Sep 17 00:00:00 2001 From: Tom White Date: Sun, 19 May 2024 10:33:01 +0100 Subject: [PATCH] Limited implementation of pad (#461) --- cubed/__init__.py | 2 ++ cubed/array_api/manipulation_functions.py | 19 ++++++++++----- cubed/pad.py | 29 +++++++++++++++++++++++ cubed/tests/test_pad.py | 21 ++++++++++++++++ docs/api.rst | 1 + 5 files changed, 66 insertions(+), 6 deletions(-) create mode 100644 cubed/pad.py create mode 100644 cubed/tests/test_pad.py diff --git a/cubed/__init__.py b/cubed/__init__.py index ffa3abe6..7518ddc1 100644 --- a/cubed/__init__.py +++ b/cubed/__init__.py @@ -19,6 +19,7 @@ from .core.ops import from_array, from_zarr, map_blocks, store, to_zarr from .nan_functions import nanmean, nansum from .overlap import map_overlap +from .pad import pad from .runtime.types import Callback, TaskEndEvent from .spec import Spec @@ -38,6 +39,7 @@ "measure_reserved_mem", "nanmean", "nansum", + "pad", "store", "to_zarr", "visualize", diff --git a/cubed/array_api/manipulation_functions.py b/cubed/array_api/manipulation_functions.py index 9dbb2417..115dc1f3 100644 --- a/cubed/array_api/manipulation_functions.py +++ b/cubed/array_api/manipulation_functions.py @@ -73,7 +73,7 @@ def _broadcast_like(x, template): return nxp.broadcast_to(x, template.shape) -def concat(arrays, /, *, axis=0): +def concat(arrays, /, *, axis=0, chunks=None): if not arrays: raise ValueError("Need array(s) to concat") @@ -93,7 +93,10 @@ def concat(arrays, /, *, axis=0): axis = validate_axis(axis, a.ndim) shape = a.shape[:axis] + (offsets[-1],) + a.shape[axis + 1 :] dtype = a.dtype - chunks = normalize_chunks(to_chunksize(a.chunks), shape=shape, dtype=dtype) + if chunks is None: + chunks = normalize_chunks(to_chunksize(a.chunks), shape=shape, dtype=dtype) + else: + chunks = normalize_chunks(chunks, shape=shape, dtype=dtype) # memory allocated by reading one chunk from input array # note that although the output chunk will overlap multiple input chunks, @@ -107,20 +110,24 @@ def concat(arrays, /, *, axis=0): dtype=dtype, chunks=chunks, extra_projected_mem=extra_projected_mem, + target_chunks=chunks, axis=axis, offsets=offsets, ) -def _read_concat_chunk(x, *arrays, axis=None, offsets=None, block_id=None): +def _read_concat_chunk( + x, *arrays, target_chunks=None, axis=None, offsets=None, block_id=None +): # determine the start and stop indexes for this block along the axis dimension - chunks = arrays[0].zarray.chunks - start = block_id[axis] * chunks[axis] + chunks = target_chunks + chunksize = to_chunksize(chunks) + start = block_id[axis] * chunksize[axis] stop = start + x.shape[axis] # produce a key that has slices (except for axis dimension, which is replaced below) idx = tuple(0 if i == axis else v for i, v in enumerate(block_id)) - key = get_item(arrays[0].chunks, idx) + key = get_item(chunks, idx) # concatenate slices of the arrays parts = [] diff --git a/cubed/pad.py b/cubed/pad.py new file mode 100644 index 00000000..c292c65e --- /dev/null +++ b/cubed/pad.py @@ -0,0 +1,29 @@ +from cubed.array_api.manipulation_functions import concat + + +def pad(x, pad_width, mode=None, chunks=None): + """Pad an array.""" + if len(pad_width) != x.ndim: + raise ValueError("`pad_width` must have as many entries as array dimensions") + axis = tuple( + i + for (i, (before, after)) in enumerate(pad_width) + if (before != 0 or after != 0) + ) + if len(axis) != 1: + raise ValueError("only one axis can be padded") + axis = axis[0] + if pad_width[axis] != (1, 0): + raise ValueError("only a pad width of (1, 0) is allowed") + if mode != "symmetric": + raise ValueError(f"Mode is not supported: {mode}") + + select = [] + for i in range(x.ndim): + if i == axis: + select.append(slice(0, 1)) + else: + select.append(slice(None)) + select = tuple(select) + a = x[select] + return concat([a, x], axis=axis, chunks=chunks or x.chunksize) diff --git a/cubed/tests/test_pad.py b/cubed/tests/test_pad.py new file mode 100644 index 00000000..7ba985f4 --- /dev/null +++ b/cubed/tests/test_pad.py @@ -0,0 +1,21 @@ +import numpy as np +import pytest +from numpy.testing import assert_array_equal + +import cubed +import cubed.array_api as xp + + +@pytest.fixture() +def spec(tmp_path): + return cubed.Spec(tmp_path, allowed_mem=100000) + + +def test_pad(spec): + an = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) + + a = xp.asarray([[1, 2, 3], [4, 5, 6], [7, 8, 9]], chunks=(2, 2), spec=spec) + b = cubed.pad(a, ((1, 0), (0, 0)), mode="symmetric") + assert b.chunks == ((2, 2), (2, 1)) + + assert_array_equal(b.compute(), np.pad(an, ((1, 0), (0, 0)), mode="symmetric")) diff --git a/docs/api.rst b/docs/api.rst index ab29bee0..3c887b37 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -57,6 +57,7 @@ Non-standardised functions nanmean nansum + pad Random number generation ========================