Skip to content

Commit

Permalink
Add special handling for attribute descriptors of QuantizedTensorBase (
Browse files Browse the repository at this point in the history
…quic#3519)

Signed-off-by: Kyunggeun Lee <[email protected]>
  • Loading branch information
quic-kyunggeu authored Nov 20, 2024
1 parent a898440 commit d6fb17b
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,14 @@ class QuantizedTensorBase(torch.Tensor):

encoding: EncodingBase

_attr_descriptors = {
torch.Tensor.dtype.__get__,
torch.Tensor.device.__get__,
torch.Tensor.layout.__get__,
torch.Tensor.shape.__get__,
torch.Tensor.size,
}

_cast_ops = {
torch.Tensor.half,
torch.Tensor.float,
Expand Down Expand Up @@ -298,6 +306,9 @@ def __torch_function__(cls, func, types, args=(), kwargs=None): # pylint: disabl
return HANDLED_FUNCTIONS[func](*args, **kwargs)
ret = super().__torch_function__(func, types, args, kwargs)

if func in cls._attr_descriptors:
return ret

self, *_ = args

if not isinstance(self, QuantizedTensorBase):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -687,3 +687,33 @@ def test_use_qtensor_arg_for_passthrough_op(self, qtensor_cls, callback, scale,
for output in outputs:
assert not isinstance(output, QuantizedTensorBase)
assert not hasattr(output, 'encoding')

@pytest.mark.parametrize('qtensor_cls', [QuantizedTensor, DequantizedTensor])
def test_attribute_descriptor(self, qtensor_cls):
"""
Given: torch.Tensor and a quantized/dequantized tensor with same dtype, device, shape, ...
When: Access the following attributes
* dtype
* device
* layout
* shape
* size()
Then: The attributes from both tensors should be value-equal and type-equal
"""
tensor = torch.empty(10, 10)
qtensor = torch.empty(10, 10).as_subclass(qtensor_cls)

assert tensor.dtype == qtensor.dtype
assert type(tensor.dtype) == type(qtensor.dtype)

assert tensor.device == qtensor.device
assert type(tensor.device) == type(qtensor.device)

assert tensor.layout == qtensor.layout
assert type(tensor.layout) == type(qtensor.layout)

assert tensor.shape == qtensor.shape
assert type(tensor.shape) == type(qtensor.shape)

assert tensor.size() == qtensor.size()
assert type(tensor.size()) == type(qtensor.size())

0 comments on commit d6fb17b

Please sign in to comment.