diff --git a/torchao/quantization/subclass.py b/torchao/quantization/subclass.py index bc40ffeaff..f3a1074ba5 100644 --- a/torchao/quantization/subclass.py +++ b/torchao/quantization/subclass.py @@ -670,10 +670,9 @@ def __new__( quant_max: Optional[int] = None, zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT, dtype=None, - # TODO: remove args and kwargs - *args, - **kwargs + strides=None, ): + kwargs = {} kwargs["device"] = int_data.device kwargs["layout"] = ( kwargs.get("layout") if kwargs.get("layout", False) else int_data.layout @@ -681,7 +680,8 @@ def __new__( if dtype is None: dtype = scale.dtype kwargs["dtype"] = dtype - assert not kwargs.get("requires_grad", False) + if strides is not None: + kwargs["strides"] = strides kwargs["requires_grad"] = False return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] @@ -696,8 +696,7 @@ def __init__( quant_max: Optional[int] = None, zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT, dtype=None, - *args, - **kwargs + strides=None, ): self.int_data = int_data self.scale = scale @@ -912,6 +911,7 @@ def _apply_fn_to_data(self, fn): self.quant_max, self.zero_point_domain, dtype=self.dtype, + strides=self.stride(), ) @classmethod