Skip to content

Commit

Permalink
adding back unfold(), no idea what happened
Browse files Browse the repository at this point in the history
  • Loading branch information
ClaudiaComito committed Sep 4, 2024
1 parent eb39460 commit d1175e8
Showing 1 changed file with 86 additions and 0 deletions.
86 changes: 86 additions & 0 deletions heat/core/manipulations.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
"swapaxes",
"tile",
"topk",
"unfold",
"unique",
"vsplit",
"vstack",
Expand Down Expand Up @@ -3480,6 +3481,91 @@ def unique(
DNDarray.unique.__doc__ = unique.__doc__


def unfold(a: DNDarray, axis: int, size: int, step: int = 1):
"""
Returns a DNDarray which contains all slices of size `size` in the axis `axis`.
Behaves like torch.Tensor.unfold for DNDarrays. [torch.Tensor.unfold](https://pytorch.org/docs/stable/generated/torch.Tensor.unfold.html)
Parameters
----------
a : DNDarray
array to unfold
axis : int
axis in which unfolding happens
size : int
the size of each slice that is unfolded, must be greater than 1
step : int
the step between each slice, must be at least 1
Example:
```
>>> x = ht.arange(1., 8)
>>> x
DNDarray([1., 2., 3., 4., 5., 6., 7.], dtype=ht.float32, device=cpu:0, split=e)
>>> ht.unfold(x, 0, 2, 1)
DNDarray([[1., 2.],
[2., 3.],
[3., 4.],
[4., 5.],
[5., 6.],
[6., 7.]], dtype=ht.float32, device=cpu:0, split=None)
>>> ht.unfold(x, 0, 2, 2)
DNDarray([[1., 2.],
[3., 4.],
[5., 6.]], dtype=ht.float32, device=cpu:0, split=None)
```
Note
---------
You have to make sure that every node has at least chunk size size-1 if the split axis of the array is the unfold axis.
"""
if step < 1:
raise ValueError("step must be >= 1.")
if size <= 1:
raise ValueError("size must be > 1.")
axis = stride_tricks.sanitize_axis(a.shape, axis)
if size > a.shape[axis]:
raise ValueError(
f"maximum size for DNDarray at axis {axis} is {a.shape[axis]} but size is {size}."
)

comm = a.comm
dev = a.device
tdev = dev.torch_device

if a.split is None or comm.size == 1 or a.split != axis: # early out
ret = factories.array(
a.larray.unfold(axis, size, step), is_split=a.split, device=dev, comm=comm
)

return ret
else: # comm.size > 1 and split axis == unfold axis
# index range [0:sizedim-1-(size-1)] = [0:sizedim-size]
# --> size of axis: ceil((sizedim-size+1) / step) = floor(sizedim-size) / step)) + 1
# ret_shape = (*a_shape[:axis], int((a_shape[axis]-size)/step) + 1, a_shape[axis+1:], size)

if (size - 1 > a.lshape_map[:, axis]).any():
raise RuntimeError("Chunk-size needs to be at least size - 1.")
a.get_halo(size - 1, prev=False)

counts, displs = a.counts_displs()
displs = torch.tensor(displs, device=tdev)

# min local index in unfold axis
min_index = ((displs[comm.rank] - 1) // step + 1) * step - displs[comm.rank]
if min_index >= a.lshape[axis] or (
comm.rank == comm.size - 1 and min_index + size > a.lshape[axis]
):
loc_unfold_shape = list(a.lshape)
loc_unfold_shape[axis] = 0
ret_larray = torch.zeros((*loc_unfold_shape, size), device=tdev)
else: # unfold has local data
ret_larray = a.array_with_halos[
axis * (slice(None, None, None),) + (slice(min_index, None, None), Ellipsis)
].unfold(axis, size, step)

ret = factories.array(ret_larray, is_split=axis, device=dev, comm=comm)

return ret


def vsplit(x: DNDarray, indices_or_sections: Iterable) -> List[DNDarray, ...]:
"""
Split array into multiple sub-DNDNarrays along the 1st axis (vertically/row-wise).
Expand Down

0 comments on commit d1175e8

Please sign in to comment.