We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 4241803 commit 0833b24Copy full SHA for 0833b24
torch_xla/distributed/spmd/xla_sharding.py
@@ -65,11 +65,11 @@ def __init__(self,
65
assert (len(device_ids) == np.prod(mesh_shape))
66
# device ids are unique
67
assert len(device_ids) == len(np.unique(device_ids))
68
- # device ids are continous
69
- assert all(d < self.size() for d in device_ids - np.min(device_ids))
70
self.device_ids = device_ids
71
self.mesh_shape = mesh_shape
72
self.axis_names = axis_names
+ # device ids are continous
+ assert all(d < self.size() for d in device_ids - np.min(device_ids))
73
74
def size(self):
75
return np.prod(self.mesh_shape)
0 commit comments