-
Notifications
You must be signed in to change notification settings - Fork 14
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Limited implementation of map_overlap
- Loading branch information
Showing
5 changed files
with
259 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,139 @@ | ||
from typing import Tuple | ||
|
||
from cubed.backend_array_api import namespace as nxp | ||
from cubed.core.ops import map_direct | ||
from cubed.types import T_RectangularChunks | ||
from cubed.utils import _cumsum | ||
from cubed.vendor.dask.array.core import normalize_chunks | ||
from cubed.vendor.dask.array.overlap import coerce_boundary, coerce_depth | ||
from cubed.vendor.dask.utils import has_keyword | ||
|
||
|
||
def map_overlap( | ||
func, | ||
*args, | ||
dtype=None, | ||
chunks=None, | ||
depth=None, | ||
boundary=None, | ||
trim=False, | ||
**kwargs, | ||
): | ||
"""Apply a function to corresponding blocks from multiple input arrays with some overlap. | ||
Parameters | ||
---------- | ||
func : callable | ||
Function to apply to every block (with overlap) to produce the output array. | ||
args : arrays | ||
The Cubed arrays to map over. Note that currently only one array may be specified. | ||
dtype : np.dtype | ||
The ``dtype`` of the output array. | ||
chunks : tuple | ||
Chunk shape of blocks in the output array. | ||
depth : int, tuple, dict or list | ||
The number of elements that each block should share with its neighbors. | ||
boundary : value type, tuple, dict or list | ||
How to handle the boundaries. Note that this currently only supports constant values. | ||
trim : bool | ||
Whether or not to trim ``depth`` elements from each block after calling the map function. | ||
Currently only ``False`` is supported. | ||
**kwargs : dict | ||
Extra keyword arguments to pass to function. | ||
""" | ||
if trim: | ||
raise ValueError("trim is not supported") | ||
|
||
chunks = normalize_chunks(chunks, dtype=dtype) | ||
shape = tuple(map(sum, chunks)) | ||
|
||
# Coerce depth and boundary arguments to lists of individual | ||
# specifications for each array argument | ||
def coerce(xs, arg, fn): | ||
if not isinstance(arg, list): | ||
arg = [arg] * len(xs) | ||
return [fn(x.ndim, a) for x, a in zip(xs, arg)] | ||
|
||
depth = coerce(args, depth, coerce_depth) | ||
boundary = coerce(args, boundary, coerce_boundary) | ||
|
||
# memory allocated by reading one chunk from input array | ||
# note that although the output chunk will overlap multiple input chunks, zarr will | ||
# read the chunks in series, reusing the buffer | ||
extra_projected_mem = args[0].chunkmem # TODO: support multiple | ||
|
||
has_block_id_kw = has_keyword(func, "block_id") | ||
|
||
return map_direct( | ||
_overlap, | ||
*args, | ||
shape=shape, | ||
dtype=dtype, | ||
chunks=chunks, | ||
extra_projected_mem=extra_projected_mem, | ||
overlap_func=func, | ||
depth=depth, | ||
boundary=boundary, | ||
has_block_id_kw=has_block_id_kw, | ||
**kwargs, | ||
) | ||
|
||
|
||
def _overlap( | ||
x, | ||
*arrays, | ||
overlap_func=None, | ||
depth=None, | ||
boundary=None, | ||
has_block_id_kw=False, | ||
block_id=None, | ||
**kwargs, | ||
): | ||
a = arrays[0] # TODO: support multiple | ||
depth = depth[0] | ||
boundary = boundary[0] | ||
|
||
# First read the chunk with overlaps determined by depth, then pad boundaries second. | ||
# Do it this way round so we can do everything with one blockwise. The alternative, | ||
# which pads the entire array first (via concatenate), would result in at least one extra copy. | ||
out = a.zarray[get_item_with_depth(a.chunks, block_id, depth)] | ||
out = _pad_boundaries(out, depth, boundary, a.numblocks, block_id) | ||
if has_block_id_kw: | ||
return overlap_func(out, block_id=block_id, **kwargs) | ||
else: | ||
return overlap_func(out, **kwargs) | ||
|
||
|
||
def _clamp(minimum: int, x: int, maximum: int) -> int: | ||
return max(minimum, min(x, maximum)) | ||
|
||
|
||
def get_item_with_depth( | ||
chunks: T_RectangularChunks, idx: Tuple[int, ...], depth | ||
) -> Tuple[slice, ...]: | ||
"""Convert a chunk index to a tuple of slices with depth offsets.""" | ||
starts = tuple(_cumsum(c, initial_zero=True) for c in chunks) | ||
loc = tuple( | ||
( | ||
_clamp(0, start[i] - depth[ax], start[-1]), | ||
_clamp(0, start[i + 1] + depth[ax], start[-1]), | ||
) | ||
for ax, (i, start) in enumerate(zip(idx, starts)) | ||
) | ||
return tuple(slice(*s, None) for s in loc) | ||
|
||
|
||
def _pad_boundaries(x, depth, boundary, numblocks, block_id): | ||
for i in range(x.ndim): | ||
d = depth.get(i, 0) | ||
if d == 0 or block_id[i] not in (0, numblocks[i] - 1): | ||
continue | ||
pad_shape = list(x.shape) | ||
pad_shape[i] = d | ||
pad_shape = tuple(pad_shape) | ||
p = nxp.full_like(x, fill_value=boundary[i], shape=pad_shape) | ||
if block_id[i] == 0: # first block on axis i | ||
x = nxp.concatenate([p, x], axis=i) | ||
elif block_id[i] == numblocks[i] - 1: # last block on axis i | ||
x = nxp.concatenate([x, p], axis=i) | ||
return x |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,79 @@ | ||
import numpy as np | ||
from numpy.testing import assert_array_equal | ||
|
||
import cubed | ||
import cubed.array_api as xp | ||
|
||
|
||
def test_map_overlap_1d(): | ||
x = np.arange(6) | ||
a = xp.asarray(x, chunks=(3,)) | ||
|
||
b = cubed.map_overlap( | ||
lambda x: x, | ||
a, | ||
dtype=a.dtype, | ||
chunks=((5, 5),), | ||
depth=1, | ||
boundary=0, | ||
trim=False, | ||
) | ||
|
||
assert_array_equal(b.compute(), np.array([0, 0, 1, 2, 3, 2, 3, 4, 5, 0])) | ||
|
||
|
||
def test_map_overlap_2d(): | ||
x = np.arange(36).reshape((6, 6)) | ||
a = xp.asarray(x, chunks=(3, 3)) | ||
|
||
b = cubed.map_overlap( | ||
lambda x: x, | ||
a, | ||
dtype=a.dtype, | ||
chunks=((7, 7), (5, 5)), | ||
depth={0: 2, 1: 1}, | ||
boundary={0: 100, 1: 200}, | ||
trim=False, | ||
) | ||
|
||
expected = np.array( | ||
[ | ||
[200, 100, 100, 100, 100, 100, 100, 100, 100, 200], | ||
[200, 100, 100, 100, 100, 100, 100, 100, 100, 200], | ||
[200, 0, 1, 2, 3, 2, 3, 4, 5, 200], | ||
[200, 6, 7, 8, 9, 8, 9, 10, 11, 200], | ||
[200, 12, 13, 14, 15, 14, 15, 16, 17, 200], | ||
[200, 18, 19, 20, 21, 20, 21, 22, 23, 200], | ||
[200, 24, 25, 26, 27, 26, 27, 28, 29, 200], | ||
[200, 6, 7, 8, 9, 8, 9, 10, 11, 200], | ||
[200, 12, 13, 14, 15, 14, 15, 16, 17, 200], | ||
[200, 18, 19, 20, 21, 20, 21, 22, 23, 200], | ||
[200, 24, 25, 26, 27, 26, 27, 28, 29, 200], | ||
[200, 30, 31, 32, 33, 32, 33, 34, 35, 200], | ||
[200, 100, 100, 100, 100, 100, 100, 100, 100, 200], | ||
[200, 100, 100, 100, 100, 100, 100, 100, 100, 200], | ||
] | ||
) | ||
|
||
assert_array_equal(b.compute(), expected) | ||
|
||
|
||
def test_map_overlap_trim(): | ||
x = np.array([1, 1, 2, 3, 5, 8, 13, 21]) | ||
a = xp.asarray(x, chunks=5) | ||
|
||
def derivative(x): | ||
out = x - np.roll(x, 1) | ||
return out[1:-1] # manual trim | ||
|
||
b = cubed.map_overlap( | ||
derivative, | ||
a, | ||
dtype=a.dtype, | ||
chunks=a.chunks, | ||
depth=1, | ||
boundary=0, | ||
trim=False, | ||
) | ||
|
||
assert_array_equal(b.compute(), np.array([1, 0, 1, 1, 2, 3, 5, 8])) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
from __future__ import annotations | ||
|
||
from numbers import Integral | ||
|
||
|
||
def coerce_depth(ndim, depth): | ||
default = 0 | ||
if depth is None: | ||
depth = default | ||
if isinstance(depth, Integral): | ||
depth = (depth,) * ndim | ||
if isinstance(depth, tuple): | ||
depth = dict(zip(range(ndim), depth)) | ||
if isinstance(depth, dict): | ||
depth = {ax: depth.get(ax, default) for ax in range(ndim)} | ||
return coerce_depth_type(ndim, depth) | ||
|
||
|
||
def coerce_depth_type(ndim, depth): | ||
for i in range(ndim): | ||
if isinstance(depth[i], tuple): | ||
depth[i] = tuple(int(d) for d in depth[i]) | ||
else: | ||
depth[i] = int(depth[i]) | ||
return depth | ||
|
||
|
||
def coerce_boundary(ndim, boundary): | ||
default = "none" | ||
if boundary is None: | ||
boundary = default | ||
if not isinstance(boundary, (tuple, dict)): | ||
boundary = (boundary,) * ndim | ||
if isinstance(boundary, tuple): | ||
boundary = dict(zip(range(ndim), boundary)) | ||
if isinstance(boundary, dict): | ||
boundary = {ax: boundary.get(ax, default) for ax in range(ndim)} | ||
return boundary |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters