Skip to content

Commit

Permalink
Implement map_overlap using map_selection
Browse files: browse the repository at this point in the history
  • Loading branch information
tomwhite committed Oct 29, 2024
1 parent 1852cf9 commit 3ea32ed
Showing 1 changed file with 28 additions and 18 deletions.
46 changes: 28 additions & 18 deletions cubed/overlap.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import Tuple

from cubed.backend_array_api import namespace as nxp
from cubed.core.ops import map_direct
from cubed.core.ops import map_selection
from cubed.types import T_RectangularChunks
from cubed.utils import _cumsum
from cubed.vendor.dask.array.core import normalize_chunks
Expand Down Expand Up @@ -57,47 +57,57 @@ def coerce(xs, arg, fn):
depth = coerce(args, depth, coerce_depth)
boundary = coerce(args, boundary, coerce_boundary)

# memory allocated by reading one chunk from input array
# note that although the output chunk will overlap multiple input chunks, zarr will
# read the chunks in series, reusing the buffer
extra_projected_mem = args[0].chunkmem # TODO: support multiple
x = args[0] # TODO: support multiple input arrays

def selection_function(out_key):
out_coords = out_key[1:]
block_id = out_coords
return get_item_with_depth(x.chunks, block_id, depth[0])

max_num_input_blocks = _overlap_num_input_blocks(depth[0], x.numblocks)

has_block_id_kw = has_keyword(func, "block_id")

return map_direct(
# First read the chunk with overlaps determined by depth, then pad boundaries second.
# Do it this way round so we can do everything with one blockwise. The alternative,
# which pads the entire array first (via concatenate), would result in at least one extra copy.

return map_selection(
_overlap,
*args,
selection_function,
x,
shape=shape,
dtype=dtype,
chunks=chunks,
extra_projected_mem=extra_projected_mem,
max_num_input_blocks=max_num_input_blocks,
overlap_func=func,
depth=depth,
boundary=boundary,
numblocks=x.numblocks,
has_block_id_kw=has_block_id_kw,
**kwargs,
)


def _overlap_num_input_blocks(depth, numblocks):
num = 1
for i in depth.keys():
num *= min(numblocks[i], 3)
return num


def _overlap(
x,
*arrays,
a,
overlap_func=None,
depth=None,
boundary=None,
numblocks=None,
has_block_id_kw=False,
block_id=None,
**kwargs,
):
a = arrays[0] # TODO: support multiple
depth = depth[0]
boundary = boundary[0]

# First read the chunk with overlaps determined by depth, then pad boundaries second.
# Do it this way round so we can do everything with one blockwise. The alternative,
# which pads the entire array first (via concatenate), would result in at least one extra copy.
out = a.zarray[get_item_with_depth(a.chunks, block_id, depth)]
out = _pad_boundaries(out, depth, boundary, a.numblocks, block_id)
out = _pad_boundaries(a, depth, boundary, numblocks, block_id)
if has_block_id_kw:
return overlap_func(out, block_id=block_id, **kwargs)
else:
Expand Down

0 comments on commit 3ea32ed

Please sign in to comment.