diff --git a/TrainingExtensions/torch/src/python/aimet_torch/v2/quantization/tensor.py b/TrainingExtensions/torch/src/python/aimet_torch/v2/quantization/tensor.py index 7fb389219a8..ed6643b38b0 100644 --- a/TrainingExtensions/torch/src/python/aimet_torch/v2/quantization/tensor.py +++ b/TrainingExtensions/torch/src/python/aimet_torch/v2/quantization/tensor.py @@ -90,6 +90,14 @@ class QuantizedTensorBase(torch.Tensor): encoding: EncodingBase + _attr_descriptors = { + torch.Tensor.dtype.__get__, + torch.Tensor.device.__get__, + torch.Tensor.layout.__get__, + torch.Tensor.shape.__get__, + torch.Tensor.size, + } + _cast_ops = { torch.Tensor.half, torch.Tensor.float, @@ -298,6 +306,9 @@ def __torch_function__(cls, func, types, args=(), kwargs=None): # pylint: disabl return HANDLED_FUNCTIONS[func](*args, **kwargs) ret = super().__torch_function__(func, types, args, kwargs) + if func in cls._attr_descriptors: + return ret + self, *_ = args if not isinstance(self, QuantizedTensorBase): diff --git a/TrainingExtensions/torch/test/python/v2/quantization/test_tensor.py b/TrainingExtensions/torch/test/python/v2/quantization/test_tensor.py index 704d04107c0..beffdc8291d 100644 --- a/TrainingExtensions/torch/test/python/v2/quantization/test_tensor.py +++ b/TrainingExtensions/torch/test/python/v2/quantization/test_tensor.py @@ -687,3 +687,33 @@ def test_use_qtensor_arg_for_passthrough_op(self, qtensor_cls, callback, scale, for output in outputs: assert not isinstance(output, QuantizedTensorBase) assert not hasattr(output, 'encoding') + + @pytest.mark.parametrize('qtensor_cls', [QuantizedTensor, DequantizedTensor]) + def test_attribute_descriptor(self, qtensor_cls): + """ + Given: torch.Tensor and a quantized/dequantized tensor with same dtype, device, shape, ... + When: Access the following attributes + * dtype + * device + * layout + * shape + * size() + Then: The attributes from both tensors should be value-equal and type-equal + """ + tensor = torch.empty(10, 10) + qtensor = torch.empty(10, 10).as_subclass(qtensor_cls) + + assert tensor.dtype == qtensor.dtype + assert type(tensor.dtype) == type(qtensor.dtype) + + assert tensor.device == qtensor.device + assert type(tensor.device) == type(qtensor.device) + + assert tensor.layout == qtensor.layout + assert type(tensor.layout) == type(qtensor.layout) + + assert tensor.shape == qtensor.shape + assert type(tensor.shape) == type(qtensor.shape) + + assert tensor.size() == qtensor.size() + assert type(tensor.size()) == type(qtensor.size())