Skip to content

Commit 0833b24

Browse files
committed
fix assertion
1 parent 4241803 commit 0833b24

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

torch_xla/distributed/spmd/xla_sharding.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -65,11 +65,11 @@ def __init__(self,
6565
assert (len(device_ids) == np.prod(mesh_shape))
6666
# device ids are unique
6767
assert len(device_ids) == len(np.unique(device_ids))
68-
# device ids are continous
69-
assert all(d < self.size() for d in device_ids - np.min(device_ids))
7068
self.device_ids = device_ids
7169
self.mesh_shape = mesh_shape
7270
self.axis_names = axis_names
71+
# device ids are continous
72+
assert all(d < self.size() for d in device_ids - np.min(device_ids))
7373

7474
def size(self):
7575
return np.prod(self.mesh_shape)

0 commit comments

Comments
 (0)