From eb39460c9b47920008f6008e4fff6ef0e1ac640b Mon Sep 17 00:00:00 2001
From: Claudia Comito <39374113+ClaudiaComito@users.noreply.github.com>
Date: Wed, 4 Sep 2024 12:37:04 +0200
Subject: [PATCH] Update manipulations

---
 heat/core/manipulations.py | 173 ++++++++++++++++---------------------
 1 file changed, 75 insertions(+), 98 deletions(-)

diff --git a/heat/core/manipulations.py b/heat/core/manipulations.py
index 4e600e208..401a21c73 100644
--- a/heat/core/manipulations.py
+++ b/heat/core/manipulations.py
@@ -61,7 +61,6 @@
     "unique",
     "vsplit",
     "vstack",
-    "unfold",
 ]
 
 
@@ -963,14 +962,36 @@ def expand_dims(a: DNDarray, axis: int) -> DNDarray:
     # sanitize input
     sanitation.sanitize_in(a)
 
-    # sanitize axis, introduce arbitrary dummy dimension to model expansion
-    axis = stride_tricks.sanitize_axis(a.shape + (1,), axis)
+    # track split axis
+    split_bookkeeping = [None] * a.ndim
+    if a.split is not None:
+        split_bookkeeping[a.split] = "split"
+    output_shape = list(a.shape)
+
+    local_expansion = a.larray
+    if isinstance(axis, (tuple, list)):
+        # sanitize axis, introduce arbitrary dummy dimensions to model expansion
+        axis = stride_tricks.sanitize_axis(a.shape + (1,) * len(axis), axis)
+        for ax in axis:
+            split_bookkeeping.insert(ax, None)
+            output_shape.insert(ax, 1)
+            local_expansion = local_expansion.unsqueeze(dim=ax)
+
+    else:
+        # sanitize axis, introduce an arbitrary dummy dimension to model expansion
+        axis = stride_tricks.sanitize_axis(a.shape + (1,), axis)
+        split_bookkeeping.insert(axis, None)
+        output_shape.insert(axis, 1)
+        local_expansion = local_expansion.unsqueeze(dim=axis)
+
+    output_split = split_bookkeeping.index("split") if "split" in split_bookkeeping else None
+    output_shape = tuple(output_shape)
 
     return DNDarray(
-        a.larray.unsqueeze(dim=axis),
-        a.shape[:axis] + (1,) + a.shape[axis:],
+        local_expansion,
+        output_shape,
         a.dtype,
-        a.split if a.split is None or a.split < axis else a.split + 1,
+        output_split,
         a.device,
         a.comm,
         a.balanced,
@@ -2005,7 +2026,9 @@ def reshape(a: DNDarray, *shape: Union[int, Tuple[int, ...]], **kwargs) -> DNDar
         Shape of the new array. Must be compatible with the original shape. If an integer, then the result will be a 1-D array of that length.
         One shape dimension can be -1. In this case, the value is inferred from the length of the array and remaining dimensions.
     new_split : int, optional
-        The distribution axis of the reshaped array. Default: None (same distribution axis as `a`).
+        The distribution axis of the reshaped array. If `new_split` is not provided, the reshaped array will have:
+        - the same split axis as the input array, if the number of dimensions does not change;
+        - split axis 0, if the number of dimensions changes.
 
     Raises
     ------
@@ -2014,7 +2037,7 @@ def reshape(a: DNDarray, *shape: Union[int, Tuple[int, ...]], **kwargs) -> DNDar
 
     Notes
     -----
-    `reshape()` might require significant communication among processes. Operating along split axis 0 is recommended.
+    `reshape()` might require significant communication among processes. Communication is minimized if the input array is distributed along axis 0, i.e. `a.split == 0`.
 
     See Also
     --------
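The split-axis bookkeeping introduced in the `expand_dims` hunk above can be checked in isolation. A minimal standalone sketch of the same trick in plain Python (no Heat required; the shape, split axis, and inserted axes below are assumed toy values, not taken from the patch):

```
# Sketch of the split-axis bookkeeping used in expand_dims above.
shape = [4, 5, 6]   # assumed global shape
split = 1           # assumed distribution axis
axes = (0, 3)       # axes to insert, already sanitized against the expanded ndim

split_bookkeeping = [None] * len(shape)
if split is not None:
    split_bookkeeping[split] = "split"  # mark where the split axis currently sits
output_shape = list(shape)
for ax in axes:
    split_bookkeeping.insert(ax, None)  # each dummy entry shifts the marker as needed
    output_shape.insert(ax, 1)

output_split = split_bookkeeping.index("split") if "split" in split_bookkeeping else None
print(output_shape, output_split)  # [1, 4, 5, 1, 6] 2: the old axis 1 is now axis 2
```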
@@ -2032,6 +2055,44 @@ def reshape(a: DNDarray, *shape: Union[int, Tuple[int, ...]], **kwargs) -> DNDar
     >>> ht.reshape(a, (2,4))
     (1/2) tensor([[0., 2., 4., 6.]])
     (2/2) tensor([[ 8., 10., 12., 14.]])
+    # 3-dim array, distributed along axis 1
+    >>> a = ht.random.rand(2, 3, 4, split=1)
+    >>> a
+    DNDarray([[[0.5525, 0.5434, 0.9477, 0.9503],
+               [0.4165, 0.3924, 0.3310, 0.3935],
+               [0.1008, 0.1750, 0.9030, 0.8579]],
+
+              [[0.0680, 0.4944, 0.4114, 0.6669],
+               [0.6423, 0.2625, 0.5413, 0.2225],
+               [0.0197, 0.5079, 0.4739, 0.4387]]], dtype=ht.float32, device=cpu:0, split=1)
+    >>> a.reshape(-1, 3)  # reshape to 2-dim array: split axis will be set to 0
+    DNDarray([[0.5525, 0.5434, 0.9477],
+              [0.9503, 0.4165, 0.3924],
+              [0.3310, 0.3935, 0.1008],
+              [0.1750, 0.9030, 0.8579],
+              [0.0680, 0.4944, 0.4114],
+              [0.6669, 0.6423, 0.2625],
+              [0.5413, 0.2225, 0.0197],
+              [0.5079, 0.4739, 0.4387]], dtype=ht.float32, device=cpu:0, split=0)
+    >>> a.reshape(2, 3, 2, 2, new_split=1)  # reshape to 4-dim array, specify distribution axis
+    DNDarray([[[[0.5525, 0.5434],
+                [0.9477, 0.9503]],
+
+               [[0.4165, 0.3924],
+                [0.3310, 0.3935]],
+
+               [[0.1008, 0.1750],
+                [0.9030, 0.8579]]],
+
+
+              [[[0.0680, 0.4944],
+                [0.4114, 0.6669]],
+
+               [[0.6423, 0.2625],
+                [0.5413, 0.2225]],
+
+               [[0.0197, 0.5079],
+                [0.4739, 0.4387]]]], dtype=ht.float32, device=cpu:0, split=1)
     """
     if not isinstance(a, DNDarray):
         raise TypeError(f"'a' must be a DNDarray, currently {type(a)}")
@@ -2069,7 +2130,12 @@
     # check new_split parameter
     new_split = kwargs.get("new_split")
     if new_split is None:
-        new_split = orig_split
+        if orig_split is not None and len(shape) != a.ndim:
+            # dimensionality reduced or expanded
+            # set output split axis to 0
+            new_split = 0
+        else:
+            new_split = orig_split
     new_split = stride_tricks.sanitize_axis(shape, new_split)
 
     if not a.is_distributed():
@@ -4214,92 +4280,3 @@ def mpi_topk(a, b, mpi_type):
 
 
 MPI_TOPK = MPI.Op.Create(mpi_topk, commute=True)
-
-
-def unfold(a: DNDarray, axis: int, size: int, step: int = 1):
-    """
-    Returns a DNDarray which contains all slices of size `size` in the axis `axis`.
-
-    Behaves like torch.Tensor.unfold for DNDarrays. [torch.Tensor.unfold](https://pytorch.org/docs/stable/generated/torch.Tensor.unfold.html)
-
-    Parameters
-    ----------
-    a : DNDarray
-        array to unfold
-    axis : int
-        axis in which unfolding happens
-    size : int
-        the size of each slice that is unfolded, must be greater than 1
-    step : int
-        the step between each slice, must be at least 1
-
-    Example:
-    ```
-    >>> x = ht.arange(1., 8)
-    >>> x
-    DNDarray([1., 2., 3., 4., 5., 6., 7.], dtype=ht.float32, device=cpu:0, split=None)
-    >>> ht.unfold(x, 0, 2, 1)
-    DNDarray([[1., 2.],
-              [2., 3.],
-              [3., 4.],
-              [4., 5.],
-              [5., 6.],
-              [6., 7.]], dtype=ht.float32, device=cpu:0, split=None)
-    >>> ht.unfold(x, 0, 2, 2)
-    DNDarray([[1., 2.],
-              [3., 4.],
-              [5., 6.]], dtype=ht.float32, device=cpu:0, split=None)
-    ```
-
-    Note
-    ----
-    If the split axis of the array is the unfold axis, every process must hold at least `size - 1` elements along that axis.
-    """
-    if step < 1:
-        raise ValueError("step must be >= 1.")
-    if size <= 1:
-        raise ValueError("size must be > 1.")
-    axis = stride_tricks.sanitize_axis(a.shape, axis)
-    if size > a.shape[axis]:
-        raise ValueError(
-            f"maximum size for DNDarray at axis {axis} is {a.shape[axis]} but size is {size}."
-        )
-
-    comm = a.comm
-    dev = a.device
-    tdev = dev.torch_device
-
-    if a.split is None or comm.size == 1 or a.split != axis:  # early out
-        ret = factories.array(
-            a.larray.unfold(axis, size, step), is_split=a.split, device=dev, comm=comm
-        )
-
-        return ret
-    else:  # comm.size > 1 and split axis == unfold axis
-        # index range [0:sizedim-1-(size-1)] = [0:sizedim-size]
-        # --> size of axis: ceil((sizedim-size+1) / step) = floor((sizedim-size) / step) + 1
-        # ret_shape = (*a_shape[:axis], int((a_shape[axis]-size)/step) + 1, *a_shape[axis+1:], size)
-
-        if (size - 1 > a.lshape_map[:, axis]).any():
-            raise RuntimeError("Chunk-size needs to be at least size - 1.")
-        a.get_halo(size - 1, prev=False)
-
-        counts, displs = a.counts_displs()
-        displs = torch.tensor(displs, device=tdev)
-
-        # min local index in unfold axis
-        min_index = ((displs[comm.rank] - 1) // step + 1) * step - displs[comm.rank]
-        if min_index >= a.lshape[axis] or (
-            comm.rank == comm.size - 1 and min_index + size > a.lshape[axis]
-        ):
-            loc_unfold_shape = list(a.lshape)
-            loc_unfold_shape[axis] = 0
-            ret_larray = torch.zeros((*loc_unfold_shape, size), device=tdev)
-        else:  # unfold has local data
-            ret_larray = a.array_with_halos[
-                axis * (slice(None, None, None),) + (slice(min_index, None, None), Ellipsis)
-            ].unfold(axis, size, step)
-
-        ret = factories.array(ret_larray, is_split=axis, device=dev, comm=comm)
-
-        return ret
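The default `new_split` choice added in the reshape hunk reduces to a small pure function. A sketch for illustration (`default_new_split` is a hypothetical helper that mirrors only the patched branch, not part of Heat's API):

```
# Sketch of the default output-split choice made in the reshape hunk above.
def default_new_split(orig_split, old_ndim, new_ndim):
    if orig_split is not None and new_ndim != old_ndim:
        # dimensionality reduced or expanded: fall back to split axis 0
        return 0
    return orig_split  # same ndim, or input not distributed: keep the split axis

print(default_new_split(1, 3, 2))     # 0    -> ndim reduced
print(default_new_split(1, 3, 3))     # 1    -> ndim unchanged
print(default_new_split(None, 3, 2))  # None -> input not distributed
```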
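For reference, the window-ownership arithmetic in the removed `unfold` (the `min_index` line) can be verified standalone. A minimal sketch, assuming toy chunk offsets in place of `counts_displs()`:

```
# Sketch of the `min_index` arithmetic from the removed unfold above.
# Global windows start at 0, step, 2*step, ...; each rank keeps the windows
# whose start falls inside its own chunk, reading into the halo when needed.
size, step = 3, 2    # assumed window size and step
displs = [0, 5, 9]   # assumed global start offset of each rank's chunk

for rank, d in enumerate(displs):
    # smallest multiple of `step` that is >= d, converted to a local index
    min_index = ((d - 1) // step + 1) * step - d
    print(rank, min_index)
# rank 0 -> 0 (first window starts at global index 0)
# rank 1 -> 1 (first multiple of 2 at or after 5 is 6, local index 6 - 5 = 1)
# rank 2 -> 1 (first multiple of 2 at or after 9 is 10, local index 10 - 9 = 1)
```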