Skip to content

Commit

Permalink
Limited implementation of pad (#461)
Browse files Browse the repository at this point in the history
  • Loading branch information
tomwhite authored May 19, 2024
1 parent b57fa0c commit fb12b09
Show file tree
Hide file tree
Showing 5 changed files with 66 additions and 6 deletions.
2 changes: 2 additions & 0 deletions cubed/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from .core.ops import from_array, from_zarr, map_blocks, store, to_zarr
from .nan_functions import nanmean, nansum
from .overlap import map_overlap
from .pad import pad
from .runtime.types import Callback, TaskEndEvent
from .spec import Spec

Expand All @@ -38,6 +39,7 @@
"measure_reserved_mem",
"nanmean",
"nansum",
"pad",
"store",
"to_zarr",
"visualize",
Expand Down
19 changes: 13 additions & 6 deletions cubed/array_api/manipulation_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def _broadcast_like(x, template):
return nxp.broadcast_to(x, template.shape)


def concat(arrays, /, *, axis=0):
def concat(arrays, /, *, axis=0, chunks=None):
if not arrays:
raise ValueError("Need array(s) to concat")

Expand All @@ -93,7 +93,10 @@ def concat(arrays, /, *, axis=0):
axis = validate_axis(axis, a.ndim)
shape = a.shape[:axis] + (offsets[-1],) + a.shape[axis + 1 :]
dtype = a.dtype
chunks = normalize_chunks(to_chunksize(a.chunks), shape=shape, dtype=dtype)
if chunks is None:
chunks = normalize_chunks(to_chunksize(a.chunks), shape=shape, dtype=dtype)
else:
chunks = normalize_chunks(chunks, shape=shape, dtype=dtype)

# memory allocated by reading one chunk from input array
# note that although the output chunk will overlap multiple input chunks,
Expand All @@ -107,20 +110,24 @@ def concat(arrays, /, *, axis=0):
dtype=dtype,
chunks=chunks,
extra_projected_mem=extra_projected_mem,
target_chunks=chunks,
axis=axis,
offsets=offsets,
)


def _read_concat_chunk(x, *arrays, axis=None, offsets=None, block_id=None):
def _read_concat_chunk(
x, *arrays, target_chunks=None, axis=None, offsets=None, block_id=None
):
# determine the start and stop indexes for this block along the axis dimension
chunks = arrays[0].zarray.chunks
start = block_id[axis] * chunks[axis]
chunks = target_chunks
chunksize = to_chunksize(chunks)
start = block_id[axis] * chunksize[axis]
stop = start + x.shape[axis]

# produce a key that has slices (except for axis dimension, which is replaced below)
idx = tuple(0 if i == axis else v for i, v in enumerate(block_id))
key = get_item(arrays[0].chunks, idx)
key = get_item(chunks, idx)

# concatenate slices of the arrays
parts = []
Expand Down
29 changes: 29 additions & 0 deletions cubed/pad.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
from cubed.array_api.manipulation_functions import concat


def pad(x, pad_width, mode=None, chunks=None):
"""Pad an array."""
if len(pad_width) != x.ndim:
raise ValueError("`pad_width` must have as many entries as array dimensions")
axis = tuple(
i
for (i, (before, after)) in enumerate(pad_width)
if (before != 0 or after != 0)
)
if len(axis) != 1:
raise ValueError("only one axis can be padded")
axis = axis[0]
if pad_width[axis] != (1, 0):
raise ValueError("only a pad width of (1, 0) is allowed")
if mode != "symmetric":
raise ValueError(f"Mode is not supported: {mode}")

select = []
for i in range(x.ndim):
if i == axis:
select.append(slice(0, 1))
else:
select.append(slice(None))
select = tuple(select)
a = x[select]
return concat([a, x], axis=axis, chunks=chunks or x.chunksize)
21 changes: 21 additions & 0 deletions cubed/tests/test_pad.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import numpy as np
import pytest
from numpy.testing import assert_array_equal

import cubed
import cubed.array_api as xp


@pytest.fixture()
def spec(tmp_path):
return cubed.Spec(tmp_path, allowed_mem=100000)


def test_pad(spec):
an = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])

a = xp.asarray([[1, 2, 3], [4, 5, 6], [7, 8, 9]], chunks=(2, 2), spec=spec)
b = cubed.pad(a, ((1, 0), (0, 0)), mode="symmetric")
assert b.chunks == ((2, 2), (2, 1))

assert_array_equal(b.compute(), np.pad(an, ((1, 0), (0, 0)), mode="symmetric"))
1 change: 1 addition & 0 deletions docs/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ Non-standardised functions

nanmean
nansum
pad

Random number generation
========================
Expand Down

0 comments on commit fb12b09

Please sign in to comment.