diff --git a/merlin/dtypes/shape.py b/merlin/dtypes/shape.py index db37b643f..c7fec05a7 100644 --- a/merlin/dtypes/shape.py +++ b/merlin/dtypes/shape.py @@ -153,7 +153,7 @@ def __post_init__(self): if self.dims is not None: new_dims = [] for i, dim in enumerate(self.dims): - if i == 0 and dim is None: + if i == 0 and (dim is None or dim == (0, None)): dim = -1 new_dim = Dimension(dim) new_dims.append(new_dim)