
Update manipulations
ClaudiaComito committed Sep 4, 2024
1 parent e5eb012 commit eb39460
Showing 1 changed file with 75 additions and 98 deletions.
heat/core/manipulations.py
@@ -61,7 +61,6 @@
     "unique",
     "vsplit",
     "vstack",
-    "unfold",
 ]


@@ -963,14 +962,36 @@ def expand_dims(a: DNDarray, axis: int) -> DNDarray:
     # sanitize input
     sanitation.sanitize_in(a)

-    # sanitize axis, introduce arbitrary dummy dimension to model expansion
-    axis = stride_tricks.sanitize_axis(a.shape + (1,), axis)
+    # track split axis
+    split_bookkeeping = [None] * a.ndim
+    if a.split is not None:
+        split_bookkeeping[a.split] = "split"
+    output_shape = list(a.shape)
+
+    local_expansion = a.larray
+    if isinstance(axis, (tuple, list)):
+        # sanitize axis, introduce arbitrary dummy dimensions to model expansion
+        axis = stride_tricks.sanitize_axis(a.shape + (1,) * len(axis), axis)
+        for ax in axis:
+            split_bookkeeping.insert(ax, None)
+            output_shape.insert(ax, 1)
+            local_expansion = local_expansion.unsqueeze(dim=ax)
+
+    else:
+        # sanitize axis, introduce arbitrary dummy dimensions to model expansion
+        axis = stride_tricks.sanitize_axis(a.shape + (1,), axis)
+        split_bookkeeping.insert(axis, None)
+        output_shape.insert(axis, 1)
+        local_expansion = local_expansion.unsqueeze(dim=axis)
+
+    output_split = split_bookkeeping.index("split") if "split" in split_bookkeeping else None
+    output_shape = tuple(output_shape)

     return DNDarray(
-        a.larray.unsqueeze(dim=axis),
-        a.shape[:axis] + (1,) + a.shape[axis:],
+        local_expansion,
+        output_shape,
         a.dtype,
-        a.split if a.split is None or a.split < axis else a.split + 1,
+        output_split,
         a.device,
         a.comm,
         a.balanced,
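The new bookkeeping threads the split axis through the inserted singleton dimensions rather than adjusting it with the old single-axis arithmetic. A minimal usage sketch of the tuple-axis behaviour (hypothetical session; shapes follow the bookkeeping above, and the printed values assume a multi-process run):

```python
import heat as ht

a = ht.zeros((2, 3), split=0)        # distributed along axis 0
b = ht.expand_dims(a, axis=(0, 3))   # insert singleton axes at positions 0 and 3
print(b.shape)   # (1, 2, 3, 1)
print(b.split)   # 1 -- the original split axis 0 was shifted right by one insertion
```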
@@ -2005,7 +2026,9 @@ def reshape(a: DNDarray, *shape: Union[int, Tuple[int, ...]], **kwargs) -> DNDarray:
         Shape of the new array. Must be compatible with the original shape. If an integer, then the result will be a 1-D array of that length.
         One shape dimension can be -1. In this case, the value is inferred from the length of the array and remaining dimensions.
     new_split : int, optional
-        The distribution axis of the reshaped array. Default: None (same distribution axis as `a`).
+        The distribution axis of the reshaped array. If `new_split` is not provided, the reshaped array will have:
+        - the same split axis as the input array, if the number of dimensions is not reduced;
+        - split axis 0, if the number of dimensions is reduced.

     Raises
     ------
@@ -2014,7 +2037,7 @@ def reshape(a: DNDarray, *shape: Union[int, Tuple[int, ...]], **kwargs) -> DNDarray:

     Notes
     -----
-    `reshape()` might require significant communication among processes. Operating along split axis 0 is recommended.
+    `reshape()` might require significant communication among processes. Communication is minimized if the input array is distributed along axis 0, i.e. `a.split == 0`.

     See Also
     --------
@@ -2032,6 +2055,44 @@ def reshape(a: DNDarray, *shape: Union[int, Tuple[int, ...]], **kwargs) -> DNDarray:
     >>> ht.reshape(a, (2,4))
     (1/2) tensor([[0., 2., 4., 6.]])
     (2/2) tensor([[ 8., 10., 12., 14.]])
+    # 3-dim array, distributed along axis 1
+    >>> a = ht.random.rand(2, 3, 4, split=1)
+    >>> a
+    DNDarray([[[0.5525, 0.5434, 0.9477, 0.9503],
+               [0.4165, 0.3924, 0.3310, 0.3935],
+               [0.1008, 0.1750, 0.9030, 0.8579]],
+
+              [[0.0680, 0.4944, 0.4114, 0.6669],
+               [0.6423, 0.2625, 0.5413, 0.2225],
+               [0.0197, 0.5079, 0.4739, 0.4387]]], dtype=ht.float32, device=cpu:0, split=1)
+    >>> a.reshape(-1, 3)  # reshape to 2-dim array: split axis will be set to 0
+    DNDarray([[0.5525, 0.5434, 0.9477],
+              [0.9503, 0.4165, 0.3924],
+              [0.3310, 0.3935, 0.1008],
+              [0.1750, 0.9030, 0.8579],
+              [0.0680, 0.4944, 0.4114],
+              [0.6669, 0.6423, 0.2625],
+              [0.5413, 0.2225, 0.0197],
+              [0.5079, 0.4739, 0.4387]], dtype=ht.float32, device=cpu:0, split=0)
+    >>> a.reshape(2, 3, 2, 2, new_split=1)  # reshape to 4-dim array, specify distribution axis
+    DNDarray([[[[0.5525, 0.5434],
+                [0.9477, 0.9503]],
+
+               [[0.4165, 0.3924],
+                [0.3310, 0.3935]],
+
+               [[0.1008, 0.1750],
+                [0.9030, 0.8579]]],
+
+
+              [[[0.0680, 0.4944],
+                [0.4114, 0.6669]],
+
+               [[0.6423, 0.2625],
+                [0.5413, 0.2225]],
+
+               [[0.0197, 0.5079],
+                [0.4739, 0.4387]]]], dtype=ht.float32, device=cpu:0, split=1)
     """
     if not isinstance(a, DNDarray):
         raise TypeError(f"'a' must be a DNDarray, currently {type(a)}")
@@ -2069,7 +2130,12 @@ def reshape(a: DNDarray, *shape: Union[int, Tuple[int, ...]], **kwargs) -> DNDarray:
     # check new_split parameter
     new_split = kwargs.get("new_split")
     if new_split is None:
-        new_split = orig_split
+        if orig_split is not None and len(shape) != a.ndim:
+            # dimensionality reduced or expanded
+            # set output split axis to 0
+            new_split = 0
+        else:
+            new_split = orig_split
     new_split = stride_tricks.sanitize_axis(shape, new_split)

     if not a.is_distributed():
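Together with the docstring update above, the default `new_split` now depends on whether the reshape changes the number of dimensions. A hedged sketch of the resulting defaults (assumes a distributed run; per the code comment, the fallback to axis 0 applies when the dimensionality is reduced or expanded):

```python
import heat as ht

a = ht.zeros((4, 3), split=1)   # 2-D array distributed along axis 1
b = a.reshape(2, 2, 3)          # ndim changes (2 -> 3): default split falls back to 0
print(b.split)                  # 0
c = a.reshape(3, 4)             # ndim unchanged: the original split axis is kept
print(c.split)                  # 1
```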
Expand Down Expand Up @@ -4214,92 +4280,3 @@ def mpi_topk(a, b, mpi_type):


MPI_TOPK = MPI.Op.Create(mpi_topk, commute=True)


def unfold(a: DNDarray, axis: int, size: int, step: int = 1):
"""
Returns a DNDarray which contains all slices of size `size` in the axis `axis`.
Behaves like torch.Tensor.unfold for DNDarrays. [torch.Tensor.unfold](https://pytorch.org/docs/stable/generated/torch.Tensor.unfold.html)
Parameters
----------
a : DNDarray
array to unfold
axis : int
axis in which unfolding happens
size : int
the size of each slice that is unfolded, must be greater than 1
step : int
the step between each slice, must be at least 1
Example:
```
>>> x = ht.arange(1., 8)
>>> x
DNDarray([1., 2., 3., 4., 5., 6., 7.], dtype=ht.float32, device=cpu:0, split=e)
>>> ht.unfold(x, 0, 2, 1)
DNDarray([[1., 2.],
[2., 3.],
[3., 4.],
[4., 5.],
[5., 6.],
[6., 7.]], dtype=ht.float32, device=cpu:0, split=None)
>>> ht.unfold(x, 0, 2, 2)
DNDarray([[1., 2.],
[3., 4.],
[5., 6.]], dtype=ht.float32, device=cpu:0, split=None)
```
Note
---------
You have to make sure that every node has at least chunk size size-1 if the split axis of the array is the unfold axis.
"""
if step < 1:
raise ValueError("step must be >= 1.")
if size <= 1:
raise ValueError("size must be > 1.")
axis = stride_tricks.sanitize_axis(a.shape, axis)
if size > a.shape[axis]:
raise ValueError(
f"maximum size for DNDarray at axis {axis} is {a.shape[axis]} but size is {size}."
)

comm = a.comm
dev = a.device
tdev = dev.torch_device

if a.split is None or comm.size == 1 or a.split != axis: # early out
ret = factories.array(
a.larray.unfold(axis, size, step), is_split=a.split, device=dev, comm=comm
)

return ret
else: # comm.size > 1 and split axis == unfold axis
# index range [0:sizedim-1-(size-1)] = [0:sizedim-size]
# --> size of axis: ceil((sizedim-size+1) / step) = floor(sizedim-size) / step)) + 1
# ret_shape = (*a_shape[:axis], int((a_shape[axis]-size)/step) + 1, a_shape[axis+1:], size)

if (size - 1 > a.lshape_map[:, axis]).any():
raise RuntimeError("Chunk-size needs to be at least size - 1.")
a.get_halo(size - 1, prev=False)

counts, displs = a.counts_displs()
displs = torch.tensor(displs, device=tdev)

# min local index in unfold axis
min_index = ((displs[comm.rank] - 1) // step + 1) * step - displs[comm.rank]
if min_index >= a.lshape[axis] or (
comm.rank == comm.size - 1 and min_index + size > a.lshape[axis]
):
loc_unfold_shape = list(a.lshape)
loc_unfold_shape[axis] = 0
ret_larray = torch.zeros((*loc_unfold_shape, size), device=tdev)
else: # unfold has local data
ret_larray = a.array_with_halos[
axis * (slice(None, None, None),) + (slice(min_index, None, None), Ellipsis)
].unfold(axis, size, step)

ret = factories.array(ret_larray, is_split=axis, device=dev, comm=comm)

return ret
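For reference, the offset arithmetic in the removed distributed branch selects, in local coordinates, the first unfold window that starts inside a rank's chunk; globally, windows start at indices 0, step, 2*step, and so on. A standalone sketch of just that formula, with hypothetical displacements:

```python
step = 3
displs = [0, 5, 10]   # hypothetical global start index of each rank's local chunk

for rank, displ in enumerate(displs):
    # smallest multiple of `step` that is >= displ, shifted into local coordinates
    min_index = ((displ - 1) // step + 1) * step - displ
    print(rank, min_index)   # prints: 0 0, then 1 1, then 2 2
```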
