Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

caught empty list of partitions in xu.merge_partitions() #280

Merged
12 changes: 11 additions & 1 deletion tests/test_partitioning.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,10 +132,15 @@ def test_partition_roundtrip(self):
reordered = back.isel(mesh2d_nFaces=order)
assert reordered["face_z"].equals(self.uds["face_z"])

def test_merge_partition_single(self):
partitions = [self.uds]
back = pt.merge_partitions(partitions)
assert back == self.uds

def test_merge_partitions__errors(self):
partitions = self.uds.ugrid.partition(n_part=2)
with pytest.raises(TypeError, match="Expected UgridDataArray or UgridDataset"):
pt.merge_partitions(p.ugrid.obj for p in partitions)
pt.merge_partitions([p.ugrid.obj for p in partitions])

grid1 = partitions[1].ugrid.grid
partitions[1]["extra"] = (grid1.face_dimension, np.ones(grid1.n_face))
Expand All @@ -162,6 +167,11 @@ def test_merge_partitions__errors(self):
):
pt.merge_partitions(partitions)

with pytest.raises(
ValueError, match="Cannot merge partitions: zero partitions provided."
):
xu.merge_partitions([])

def test_merge_partitions_no_duplicates(self):
part1 = self.uds.isel(mesh2d_nFaces=[0, 1, 2, 3])
part2 = self.uds.isel(mesh2d_nFaces=[2, 3, 4, 5])
Expand Down
6 changes: 6 additions & 0 deletions xugrid/ugrid/partitioning.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,6 +328,8 @@ def merge_partitions(partitions, merge_ugrid_chunks: bool = True):
-------
merged : UgridDataset
"""
if len(partitions) == 0:
raise ValueError("Cannot merge partitions: zero partitions provided.")
types = {type(obj) for obj in partitions}
msg = "Expected UgridDataArray or UgridDataset, received: {}"
if len(types) > 1:
Expand All @@ -337,6 +339,10 @@ def merge_partitions(partitions, merge_ugrid_chunks: bool = True):
if obj_type not in (UgridDataArray, UgridDataset):
raise TypeError(msg.format(obj_type.__name__))

# return first partition if single partition is provided
if len(partitions) == 1:
return next(iter(partitions))

# Collect grids
grids = [grid for p in partitions for grid in p.grids]
ugrid_dims = {dim for grid in grids for dim in grid.dimensions}
Expand Down
Loading