Support multiple axes for ht.percentile #1510

Merged 42 commits on Sep 5, 2024
Commits
ca34caa
refactor percentile
ClaudiaComito Apr 11, 2024
d4861a8
support percentile along multiple axes
ClaudiaComito Apr 11, 2024
54dcd4f
support tuple axis for expand_dims
ClaudiaComito Apr 12, 2024
0e9e587
refactor sanitation, data manipulation
ClaudiaComito Apr 12, 2024
d3c6c25
fix transpose for tuple axis
ClaudiaComito Apr 12, 2024
191a42d
adjust indices shapes for bin op, rule out complex types
ClaudiaComito Apr 13, 2024
def9c48
skip out sanitation tests, test for complex types
ClaudiaComito Apr 13, 2024
5db0b8b
Merge branch 'main' into features/1389-Speed_up_ht_percentile
ClaudiaComito Apr 13, 2024
799a4fc
Merge branch 'main' into features/1389-Speed_up_ht_percentile
ClaudiaComito Jun 4, 2024
e496515
sanitize output buffer
ClaudiaComito Jun 5, 2024
f04fc1b
sanizite output buffer
ClaudiaComito Jun 5, 2024
7ff48fa
update tests
ClaudiaComito Jun 5, 2024
610b035
Merge branch 'main' into features/1389-Speed_up_ht_percentile
ClaudiaComito Jun 5, 2024
83816a2
Merge branch 'main' into features/1389-Speed_up_ht_percentile
ClaudiaComito Jun 7, 2024
79ecbb7
Merge branch 'main' into features/1389-Speed_up_ht_percentile
ClaudiaComito Jun 10, 2024
6bf5924
expand tests
ClaudiaComito Jun 10, 2024
90d952b
expand tests
ClaudiaComito Jun 10, 2024
01259fc
Merge branch 'features/1389-Speed_up_ht_percentile' of github.com:hel…
ClaudiaComito Jun 10, 2024
0b613ba
Merge branch 'main' into features/1389-Speed_up_ht_percentile
ClaudiaComito Jul 12, 2024
f073cd7
edit docs
ClaudiaComito Jul 12, 2024
53b1888
Merge branch 'main' into features/1389-Speed_up_ht_percentile
ClaudiaComito Jul 17, 2024
eb52d27
fix split when len(new_shape) < a.ndim
ClaudiaComito Jul 19, 2024
fa5d32f
test sketched percentile with tuple axis
ClaudiaComito Jul 19, 2024
a7f5cef
small edits
ClaudiaComito Jul 19, 2024
3b81754
Update heat/core/statistics.py
ClaudiaComito Jul 19, 2024
2053ca7
Update heat/core/statistics.py
ClaudiaComito Jul 19, 2024
c64be8f
fix split after reshape
ClaudiaComito Jul 19, 2024
3a8f8a6
Merge branch 'features/1389-Speed_up_ht_percentile' of github.com:hel…
ClaudiaComito Jul 19, 2024
81ca829
Merge branch 'main' into features/1389-Speed_up_ht_percentile
mrfh92 Jul 19, 2024
75150dd
Merge branch 'main' into features/1389-Speed_up_ht_percentile
mrfh92 Jul 19, 2024
2470965
test out buffer
ClaudiaComito Aug 15, 2024
3eb0f6f
Merge branch 'main' into features/1389-Speed_up_ht_percentile
ClaudiaComito Aug 27, 2024
4f8a5b3
fix split mismatch w. output buffer
ClaudiaComito Sep 3, 2024
c21f9f1
Merge branch 'main' into features/1389-Speed_up_ht_percentile
ClaudiaComito Sep 3, 2024
472d4c8
set split to 0 when reshaping to different ndims
ClaudiaComito Sep 4, 2024
8ab6636
Merge branch 'features/1389-Speed_up_ht_percentile' of github.com:hel…
ClaudiaComito Sep 4, 2024
48e23eb
refine calc of output split
ClaudiaComito Sep 4, 2024
589e81d
edits
ClaudiaComito Sep 4, 2024
e5eb012
restore manipulations from main
ClaudiaComito Sep 4, 2024
eb39460
Update manipulations
ClaudiaComito Sep 4, 2024
d1175e8
adding back unfold(), no idea what happened
ClaudiaComito Sep 4, 2024
e77f1ff
small edits to trigger pipeline
ClaudiaComito Sep 5, 2024
259 changes: 161 additions & 98 deletions heat/core/manipulations.py
@@ -58,10 +58,10 @@
"swapaxes",
"tile",
"topk",
"unfold",
"unique",
"vsplit",
"vstack",
"unfold",
]


@@ -963,14 +963,36 @@ def expand_dims(a: DNDarray, axis: int) -> DNDarray:
# sanitize input
sanitation.sanitize_in(a)

# sanitize axis, introduce arbitrary dummy dimension to model expansion
axis = stride_tricks.sanitize_axis(a.shape + (1,), axis)
# track split axis
split_bookkeeping = [None] * a.ndim
if a.split is not None:
split_bookkeeping[a.split] = "split"
output_shape = list(a.shape)

local_expansion = a.larray
if isinstance(axis, (tuple, list)):
# sanitize axis, introduce arbitrary dummy dimensions to model expansion
axis = stride_tricks.sanitize_axis(a.shape + (1,) * len(axis), axis)
for ax in axis:
split_bookkeeping.insert(ax, None)
output_shape.insert(ax, 1)
local_expansion = local_expansion.unsqueeze(dim=ax)

else:
# sanitize axis, introduce arbitrary dummy dimensions to model expansion
axis = stride_tricks.sanitize_axis(a.shape + (1,), axis)
split_bookkeeping.insert(axis, None)
output_shape.insert(axis, 1)
local_expansion = local_expansion.unsqueeze(dim=axis)

output_split = split_bookkeeping.index("split") if "split" in split_bookkeeping else None
output_shape = tuple(output_shape)

return DNDarray(
a.larray.unsqueeze(dim=axis),
a.shape[:axis] + (1,) + a.shape[axis:],
local_expansion,
output_shape,
a.dtype,
a.split if a.split is None or a.split < axis else a.split + 1,
output_split,
a.device,
a.comm,
a.balanced,
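For reference, a minimal sketch of what the tuple-axis support added here enables (the array below is hypothetical; shapes and split positions follow the split_bookkeeping logic above):

```python
import heat as ht

a = ht.zeros((2, 3), split=1)     # 2 x 3 array, distributed along axis 1

# single axis: unchanged behaviour
b = ht.expand_dims(a, 0)          # shape (1, 2, 3), split axis shifts to 2

# tuple axis: one dummy dimension inserted per entry (new in this PR)
c = ht.expand_dims(a, (0, 2))     # shape (1, 2, 1, 3), split axis shifts to 3
```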
@@ -2005,7 +2027,9 @@ def reshape(a: DNDarray, *shape: Union[int, Tuple[int, ...]], **kwargs) -> DNDar
Shape of the new array. Must be compatible with the original shape. If an integer, then the result will be a 1-D array of that length.
One shape dimension can be -1. In this case, the value is inferred from the length of the array and remaining dimensions.
new_split : int, optional
The distribution axis of the reshaped array. Default: None (same distribution axis as `a`).
The distribution axis of the reshaped array. If `new_split` is not provided, the reshaped array will have:
- the same split axis as the input array, if the original dimensionality is unchanged;
- split axis 0, if the number of dimensions is modified by reshaping.

Raises
------
@@ -2014,7 +2038,7 @@ def reshape(a: DNDarray, *shape: Union[int, Tuple[int, ...]], **kwargs) -> DNDar

Notes
-----
`reshape()` might require significant communication among processes. Operating along split axis 0 is recommended.
`reshape()` might require significant communication among processes. Communication is minimized if the input array is distributed along axis 0, i.e. `a.split == 0`.

See Also
--------
@@ -2032,6 +2056,44 @@ def reshape(a: DNDarray, *shape: Union[int, Tuple[int, ...]], **kwargs) -> DNDar
>>> ht.reshape(a, (2,4))
(1/2) tensor([[0., 2., 4., 6.]])
(2/2) tensor([[ 8., 10., 12., 14.]])
# 3-dim array, distributed along axis 1
>>> a = ht.random.rand(2, 3, 4, split=1)
>>> a
DNDarray([[[0.5525, 0.5434, 0.9477, 0.9503],
[0.4165, 0.3924, 0.3310, 0.3935],
[0.1008, 0.1750, 0.9030, 0.8579]],

[[0.0680, 0.4944, 0.4114, 0.6669],
[0.6423, 0.2625, 0.5413, 0.2225],
[0.0197, 0.5079, 0.4739, 0.4387]]], dtype=ht.float32, device=cpu:0, split=1)
>>> a.reshape(-1, 3) # reshape to 2-dim array: split axis will be set to 0
DNDarray([[0.5525, 0.5434, 0.9477],
[0.9503, 0.4165, 0.3924],
[0.3310, 0.3935, 0.1008],
[0.1750, 0.9030, 0.8579],
[0.0680, 0.4944, 0.4114],
[0.6669, 0.6423, 0.2625],
[0.5413, 0.2225, 0.0197],
[0.5079, 0.4739, 0.4387]], dtype=ht.float32, device=cpu:0, split=0)
>>> a.reshape(2,3,2,2, new_split=1) # reshape to 4-dim array, specify distribution axis
DNDarray([[[[0.5525, 0.5434],
[0.9477, 0.9503]],

[[0.4165, 0.3924],
[0.3310, 0.3935]],

[[0.1008, 0.1750],
[0.9030, 0.8579]]],


[[[0.0680, 0.4944],
[0.4114, 0.6669]],

[[0.6423, 0.2625],
[0.5413, 0.2225]],

[[0.0197, 0.5079],
[0.4739, 0.4387]]]], dtype=ht.float32, device=cpu:0, split=1)
"""
if not isinstance(a, DNDarray):
raise TypeError(f"'a' must be a DNDarray, currently {type(a)}")
@@ -2069,7 +2131,12 @@ def reshape(a: DNDarray, *shape: Union[int, Tuple[int, ...]], **kwargs) -> DNDar
# check new_split parameter
new_split = kwargs.get("new_split")
if new_split is None:
new_split = orig_split
if orig_split is not None and len(shape) != a.ndim:
# dimensionality reduced or expanded
# set output split axis to 0
new_split = 0
else:
new_split = orig_split
new_split = stride_tricks.sanitize_axis(shape, new_split)

if not a.is_distributed():
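A short usage sketch of the new default for the output split (hypothetical shapes; assumes a run on several MPI processes so the arrays are actually distributed):

```python
import heat as ht

a = ht.random.rand(2, 3, 4, split=1)

b = a.reshape(4, 6)                      # ndim changes (3 -> 2): split defaults to 0
c = a.reshape(4, 3, 2)                   # ndim unchanged: split stays at 1
d = a.reshape(2, 3, 2, 2, new_split=1)   # explicit new_split overrides the default
```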
@@ -3414,6 +3481,91 @@ def unique(
DNDarray.unique.__doc__ = unique.__doc__


def unfold(a: DNDarray, axis: int, size: int, step: int = 1):
"""
Returns a DNDarray which contains all slices of size `size` in the axis `axis`.
Behaves like torch.Tensor.unfold for DNDarrays. [torch.Tensor.unfold](https://pytorch.org/docs/stable/generated/torch.Tensor.unfold.html)
Parameters
----------
a : DNDarray
array to unfold
axis : int
axis in which unfolding happens
size : int
the size of each slice that is unfolded, must be greater than 1
step : int
the step between each slice, must be at least 1
Example:
```
>>> x = ht.arange(1., 8)
>>> x
DNDarray([1., 2., 3., 4., 5., 6., 7.], dtype=ht.float32, device=cpu:0, split=None)
>>> ht.unfold(x, 0, 2, 1)
DNDarray([[1., 2.],
[2., 3.],
[3., 4.],
[4., 5.],
[5., 6.],
[6., 7.]], dtype=ht.float32, device=cpu:0, split=None)
>>> ht.unfold(x, 0, 2, 2)
DNDarray([[1., 2.],
[3., 4.],
[5., 6.]], dtype=ht.float32, device=cpu:0, split=None)
```
Note
----
If the split axis of the array is the unfold axis, every process must hold a local chunk of at least size - 1 elements along that axis.
"""
if step < 1:
raise ValueError("step must be >= 1.")
if size <= 1:
raise ValueError("size must be > 1.")
axis = stride_tricks.sanitize_axis(a.shape, axis)
if size > a.shape[axis]:
raise ValueError(
f"maximum size for DNDarray at axis {axis} is {a.shape[axis]} but size is {size}."
)

comm = a.comm
dev = a.device
tdev = dev.torch_device

if a.split is None or comm.size == 1 or a.split != axis: # early out
ret = factories.array(
a.larray.unfold(axis, size, step), is_split=a.split, device=dev, comm=comm
)

return ret
else: # comm.size > 1 and split axis == unfold axis
# index range [0:sizedim-1-(size-1)] = [0:sizedim-size]
# --> size of axis: ceil((sizedim - size + 1) / step) = floor((sizedim - size) / step) + 1
# ret_shape = (*a_shape[:axis], int((a_shape[axis]-size)/step) + 1, *a_shape[axis+1:], size)

if (size - 1 > a.lshape_map[:, axis]).any():
raise RuntimeError("Chunk-size needs to be at least size - 1.")
a.get_halo(size - 1, prev=False)

counts, displs = a.counts_displs()
displs = torch.tensor(displs, device=tdev)

# min local index in unfold axis
min_index = ((displs[comm.rank] - 1) // step + 1) * step - displs[comm.rank]
if min_index >= a.lshape[axis] or (
comm.rank == comm.size - 1 and min_index + size > a.lshape[axis]
):
loc_unfold_shape = list(a.lshape)
loc_unfold_shape[axis] = 0
ret_larray = torch.zeros((*loc_unfold_shape, size), device=tdev)
else: # unfold has local data
ret_larray = a.array_with_halos[
axis * (slice(None, None, None),) + (slice(min_index, None, None), Ellipsis)
].unfold(axis, size, step)

ret = factories.array(ret_larray, is_split=axis, device=dev, comm=comm)

return ret


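The min_index arithmetic above selects the first local offset whose global index lies on the global step grid, i.e. the first unfold window that starts on the calling process. A plain-Python sketch with made-up displacements (no MPI needed) to illustrate the formula:

```python
# Stand-alone check of the min_index formula from the distributed branch above.
# displs[rank] is the global index at which rank's local chunk starts (hypothetical values).
step = 2
displs = [0, 3, 6]

for rank, d in enumerate(displs):
    # smallest local offset i such that d + i is a multiple of step
    min_index = ((d - 1) // step + 1) * step - d
    print(rank, min_index, d + min_index)
# rank 0 -> local 0 (global 0), rank 1 -> local 1 (global 4), rank 2 -> local 0 (global 6)
```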
def vsplit(x: DNDarray, indices_or_sections: Iterable) -> List[DNDarray, ...]:
"""
Split array into multiple sub-DNDarrays along the 1st axis (vertically/row-wise).
@@ -4214,92 +4366,3 @@ def mpi_topk(a, b, mpi_type)


MPI_TOPK = MPI.Op.Create(mpi_topk, commute=True)


def unfold(a: DNDarray, axis: int, size: int, step: int = 1):
"""
Returns a DNDarray which contains all slices of size `size` in the axis `axis`.

Behaves like torch.Tensor.unfold for DNDarrays. [torch.Tensor.unfold](https://pytorch.org/docs/stable/generated/torch.Tensor.unfold.html)

Parameters
----------
a : DNDarray
array to unfold
axis : int
axis in which unfolding happens
size : int
the size of each slice that is unfolded, must be greater than 1
step : int
the step between each slice, must be at least 1

Example:
```
>>> x = ht.arange(1., 8)
>>> x
DNDarray([1., 2., 3., 4., 5., 6., 7.], dtype=ht.float32, device=cpu:0, split=None)
>>> ht.unfold(x, 0, 2, 1)
DNDarray([[1., 2.],
[2., 3.],
[3., 4.],
[4., 5.],
[5., 6.],
[6., 7.]], dtype=ht.float32, device=cpu:0, split=None)
>>> ht.unfold(x, 0, 2, 2)
DNDarray([[1., 2.],
[3., 4.],
[5., 6.]], dtype=ht.float32, device=cpu:0, split=None)
```

Note
----
If the split axis of the array is the unfold axis, every process must hold a local chunk of at least size - 1 elements along that axis.
"""
if step < 1:
raise ValueError("step must be >= 1.")
if size <= 1:
raise ValueError("size must be > 1.")
axis = stride_tricks.sanitize_axis(a.shape, axis)
if size > a.shape[axis]:
raise ValueError(
f"maximum size for DNDarray at axis {axis} is {a.shape[axis]} but size is {size}."
)

comm = a.comm
dev = a.device
tdev = dev.torch_device

if a.split is None or comm.size == 1 or a.split != axis: # early out
ret = factories.array(
a.larray.unfold(axis, size, step), is_split=a.split, device=dev, comm=comm
)

return ret
else: # comm.size > 1 and split axis == unfold axis
# index range [0:sizedim-1-(size-1)] = [0:sizedim-size]
# --> size of axis: ceil((sizedim - size + 1) / step) = floor((sizedim - size) / step) + 1
# ret_shape = (*a_shape[:axis], int((a_shape[axis]-size)/step) + 1, *a_shape[axis+1:], size)

if (size - 1 > a.lshape_map[:, axis]).any():
raise RuntimeError("Chunk-size needs to be at least size - 1.")
a.get_halo(size - 1, prev=False)

counts, displs = a.counts_displs()
displs = torch.tensor(displs, device=tdev)

# min local index in unfold axis
min_index = ((displs[comm.rank] - 1) // step + 1) * step - displs[comm.rank]
if min_index >= a.lshape[axis] or (
comm.rank == comm.size - 1 and min_index + size > a.lshape[axis]
):
loc_unfold_shape = list(a.lshape)
loc_unfold_shape[axis] = 0
ret_larray = torch.zeros((*loc_unfold_shape, size), device=tdev)
else: # unfold has local data
ret_larray = a.array_with_halos[
axis * (slice(None, None, None),) + (slice(min_index, None, None), Ellipsis)
].unfold(axis, size, step)

ret = factories.array(ret_larray, is_split=axis, device=dev, comm=comm)

return ret