
Commit e15e509

Allow quantized linear registration in a different file (#783)
* Allow quantized linear registration in a different file

Summary:
Previously we had to maintain a specific ordering for the quantized linear dispatch table in AffineQuantizedTensor, because of a fallback entry that dequantizes the input:
https://github.com/pytorch/ao/blob/ba2d3b1333b90ccd0186216649a1c58c6a17ce56/torchao/dtypes/affine_quantized_tensor.py#L1195
Dispatches where both inputs are quantized (static or dynamic quantization) had to come before this entry and before the weight only quantization dispatches. In practice the fallback is not really used or needed, since people typically just want to call into a very specific kernel.

From offline discussions with @drisspg and @HDCharles, it might be useful to have a "quantized_linear_impl" field on `LayoutType`. This lets people specify and check which quantized_linear_impl they want to use, so that they are sure to call into the specific kernel; when this field is set, we also skip the fallback path for quantized linear (dequantize all activation and weight tensors and run the floating point linear op). This can be added to a specific layout type if people want it, and we don't have to enforce it in the base `LayoutType`.

Test Plan:
python test/dtypes/test_affine_quantized.py -k test_register_new_dispatch

Reviewers:
Subscribers:
Tasks:
Tags:

* fix error

* de-register dispatch

* make register/deregister fn public

* rebase and fix error
1 parent e2dad4a commit e15e509
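The commit message sketches the `quantized_linear_impl` idea without adding the field to the base `LayoutType`; the changed fallback paths in the diff below only probe for it with `hasattr`. A minimal sketch of a layout type that opts in, assuming a hypothetical `MyLayoutType` that is not part of this diff:

from dataclasses import dataclass
from typing import Optional

from torchao.dtypes.utils import LayoutType

# Hypothetical layout type carrying the `quantized_linear_impl` field discussed in the
# commit message; when it is set, the F.linear / aten.addmm / aten.mm handlers below
# re-raise QuantizedLinearNotImplementedError instead of dequantizing and falling back
# to the floating point linear op.
@dataclass(frozen=True)
class MyLayoutType(LayoutType):
    quantized_linear_impl: Optional[str] = None

layout_type = MyLayoutType(quantized_linear_impl="my_uint6_weight_only_kernel")

The frozen dataclass pattern follows Float8LayoutType in this diff; the field name is the one proposed in the commit message.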

File tree: 4 files changed, +94 -32 lines


test/dtypes/test_affine_quantized.py (+38)

@@ -87,6 +87,44 @@ def test_to_device(self, apply_quant):
         ql = apply_quant(l)
         ql.cuda()
 
+    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    def test_register_new_dispatch(self):
+        from torchao.dtypes.affine_quantized_tensor import (
+            register_aqt_quantized_linear_dispatch,
+            deregister_aqt_quantized_linear_dispatch,
+        )
+        from torchao.dtypes import to_affine_quantized_intx
+        from torchao.dtypes import AffineQuantizedTensor
+        from torchao.quantization.quant_primitives import MappingType
+
+        def dispatch_condition(input_tensor, weight_tensor, bias):
+            return (
+                isinstance(weight_tensor, AffineQuantizedTensor) and
+                weight_tensor.quant_min == 0 and
+                weight_tensor.quant_max == 2**6-1
+            )
+
+        def impl(input_tensor, weight_tensor, bias):
+            # this is just for testing, normally people will call into uint6 weight only
+            # quantized linear operator here
+            assert False, "dispatching to my impl for uint6 weight only quant"
+
+        register_aqt_quantized_linear_dispatch(dispatch_condition, impl)
+
+        def apply_uint6_weight_only_quant(linear):
+            linear.weight = torch.nn.Parameter(to_affine_quantized_intx(linear.weight, MappingType.ASYMMETRIC, (1, linear.weight.shape[-1]), torch.uint8, 0, 2**6-1), requires_grad=False)
+            return linear
+
+        l = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
+        apply_uint6_weight_only_quant(l)
+
+        example_input = torch.randn(1, 128, dtype=torch.bfloat16, device="cuda")
+        with self.assertRaisesRegex(AssertionError, "dispatching to my impl for uint6 weight only quant"):
+            l(example_input)
+
+        deregister_aqt_quantized_linear_dispatch(dispatch_condition)
+
+
 
 common_utils.instantiate_parametrized_tests(TestAffineQuantized)
 

torchao/dtypes/affine_quantized_tensor.py (+48 -30)

@@ -39,6 +39,9 @@
     TORCH_VERSION_AT_LEAST_2_5,
     _is_float8_type
 )
+import logging
+
+logger = logging.getLogger(__name__)
 
 from torchao.float8.float8_tensor import ScaledMMConfig
 aten = torch.ops.aten
@@ -88,9 +91,28 @@ class QuantizedLinearNotImplementedError(NotImplementedError):
     pass
 
 
-_QLINEAR_DISPATCH_TABLE = {}
-def _register_quantized_linear_dispatch(dispatch_condition, impl):
-    _QLINEAR_DISPATCH_TABLE[dispatch_condition] = impl
+_AQT_QLINEAR_DISPATCH_TABLE = {}
+def register_aqt_quantized_linear_dispatch(dispatch_condition, impl):
+    """Register a dispatch for the quantized linear op with a dispatch_condition function and an impl function,
+    both of which take three arguments:
+      input_tensor: dimension is (M1, M2, ..., in_features)
+      weight_tensor: dimension is (out_features, in_features)
+      bias: dimension is (out_features,)
+    so that these can be shared by the F.linear, aten.mm, and aten.addmm dispatches
+
+    Args:
+        `dispatch_condition` (Callable[[torch.Tensor, torch.Tensor, torch.Tensor], bool]): the dispatch
+            condition for a specialized quantized linear implementation, e.g. bfloat16 activation + uint4 weight
+        `impl` (Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]): the specialized
+            quantized linear implementation
+    """
+    _AQT_QLINEAR_DISPATCH_TABLE[dispatch_condition] = impl
+
+def deregister_aqt_quantized_linear_dispatch(dispatch_condition):
+    if dispatch_condition in _AQT_QLINEAR_DISPATCH_TABLE:
+        del _AQT_QLINEAR_DISPATCH_TABLE[dispatch_condition]
+    else:
+        logger.warn(f"Attempting to remove non-existent dispatch condition {dispatch_condition}")
 
 class AffineQuantizedTensor(TorchAOBaseTensor):
     """
@@ -189,7 +211,7 @@ def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor
 
     @staticmethod
     def _quantized_linear_op(input_tensor, weight_tensor, bias):
-        for dispatch_condition, impl in _QLINEAR_DISPATCH_TABLE.items():
+        for dispatch_condition, impl in _AQT_QLINEAR_DISPATCH_TABLE.items():
             if dispatch_condition(input_tensor, weight_tensor, bias):
                 return impl(input_tensor, weight_tensor, bias)
         raise QuantizedLinearNotImplementedError("No specialized dispatch found for quantized linear op")
@@ -440,7 +462,7 @@ def extra_repr(self):
 
 @dataclass(frozen=True)
 class Float8LayoutType(LayoutType):
-    mm_config: Optional[ScaledMMConfig]
+    mm_config: Optional[ScaledMMConfig] = None
 
 
 @register_layout_cls(PlainLayoutType)
@@ -598,13 +620,13 @@ def from_plain(
 
 @register_layout_cls(Float8LayoutType)
 class Float8AQTLayout(AQTLayout):
-    """ 
+    """
     Layout storage class for float8 layout for affine quantized tensor
     """
     float8_data: torch.Tensor
     scale: torch.Tensor
     transposed: bool
-    
+
     def __new__(
         cls,
         float8_data: torch.Tensor,
@@ -639,7 +661,7 @@ def _apply_fn_to_data(self, fn):
         fn(self.float8_data)
         fn(self.scale)
         return self
-    
+
     def to(self, *args, **kwargs):
         kwargs = self._get_to_kwargs(*args, **kwargs)
         return self.__class__(
@@ -976,21 +998,6 @@ def _linear_int8_act_int8_weight_semi_structured_sparse_impl(input_tensor, weigh
         y += bias
     return y
 
-# this is for the case when linear activation is quantized, but is not caught by the previous
-# conditions that expects a quantized activation, we just dequantize the activation so that
-# it can continue with the weight only quantization dispatches
-# NOTE: this is a fallback path that must be registered after all the implementations that expects
-# input tensor to be quantized
-def _linear_quantized_act_fallback_check(input_tensor, weight_tensor, bias):
-    return (
-        isinstance(input_tensor, AffineQuantizedTensor)
-    )
-
-def _linear_quantized_act_fallback_impl(input_tensor, weight_tensor, bias):
-    input_tensor = input_tensor.dequantize()
-    # dequantize activation and redispatch to F.linear
-    return torch.nn.functional.linear(input_tensor, weight_tensor, bias)
-
 def _linear_bf16_act_uint4_weight_check(input_tensor, weight_tensor, bias):
     return (
         # input is native bfloat16 tensor
@@ -1187,19 +1194,18 @@ def _linear_fp_act_fp8_weight_impl(
     ).reshape(out_shape)
 
 
-def _register_quantized_linear_dispatches():
+def _register_aqt_quantized_linear_dispatches():
     for dispatch_condition, impl in [
         (_linear_int8_act_int8_weight_check, _linear_int8_act_int8_weight_impl),
         (_linear_int8_act_int8_weight_semi_structured_sparse_check, _linear_int8_act_int8_weight_semi_structured_sparse_impl),
         (_linear_fp_act_fp8_tensor_wise_weight_check, _linear_fp_act_fp8_weight_impl),
-        (_linear_quantized_act_fallback_check, _linear_quantized_act_fallback_impl),
         (_linear_bf16_act_uint4_weight_check, _linear_bf16_act_uint4_weight_impl),
         (_linear_fp_act_int8_weight_check, _linear_fp_act_int8_weight_impl),
         (_linear_f16_act_fpx_weight_check, _linear_f16_act_fpx_weight_impl),
     ]:
-        _register_quantized_linear_dispatch(dispatch_condition, impl)
+        register_aqt_quantized_linear_dispatch(dispatch_condition, impl)
 
-_register_quantized_linear_dispatches()
+_register_aqt_quantized_linear_dispatches()
 
 @implements(torch.nn.functional.linear)
 def _(func, types, args, kwargs):
@@ -1216,7 +1222,11 @@ def _(func, types, args, kwargs):
     # make the branches easier to understand in `_quantized_linear_op`
     try:
         return weight_tensor._quantized_linear_op(input_tensor, weight_tensor, bias)
-    except QuantizedLinearNotImplementedError:
+    except QuantizedLinearNotImplementedError as e:
+        # fallback path is only called when user did not specify a specific quantized linear implementation with `layout_type.quantized_linear_impl`
+        if isinstance(weight_tensor, AffineQuantizedTensor) and hasattr(weight_tensor.layout_type, "quantized_linear_impl") and weight_tensor.layout_type.quantized_linear_impl is not None:
+            raise e
+
         if isinstance(input_tensor, AffineQuantizedTensor):
             input_tensor = input_tensor.dequantize()
         if isinstance(weight_tensor, AffineQuantizedTensor):
@@ -1239,7 +1249,11 @@ def _(func, types, args, kwargs):
     try:
         weight_tensor = weight_tensor.t()
         return weight_tensor._quantized_linear_op(input_tensor, weight_tensor, bias)
-    except QuantizedLinearNotImplementedError:
+    except QuantizedLinearNotImplementedError as e:
+        # fallback path is only called when user did not specify a specific quantized linear implementation with `layout_type.quantized_linear_impl`
+        if isinstance(weight_tensor, AffineQuantizedTensor) and hasattr(weight_tensor.layout_type, "quantized_linear_impl") and weight_tensor.layout_type.quantized_linear_impl is not None:
+            raise e
+
         if isinstance(input_tensor, AffineQuantizedTensor):
             input_tensor = input_tensor.dequantize()
         if isinstance(weight_tensor, AffineQuantizedTensor):
@@ -1259,7 +1273,11 @@ def _(func, types, args, kwargs):
     try:
         weight_tensor = weight_tensor.t()
         return weight_tensor._quantized_linear_op(input_tensor, weight_tensor, bias)
-    except QuantizedLinearNotImplementedError:
+    except QuantizedLinearNotImplementedError as e:
+        # fallback path is only called when user did not specify a specific quantized linear implementation with `layout_type.quantized_linear_impl`
+        if isinstance(weight_tensor, AffineQuantizedTensor) and hasattr(weight_tensor.layout_type, "quantized_linear_impl") and weight_tensor.layout_type.quantized_linear_impl is not None:
+            raise e
+
         if isinstance(input_tensor, AffineQuantizedTensor):
             input_tensor = input_tensor.dequantize()
         if isinstance(weight_tensor, AffineQuantizedTensor):

torchao/dtypes/utils.py (+7 -1)

@@ -1,5 +1,5 @@
 import torch
-from typing import Dict, Callable, Union, Tuple
+from typing import Dict, Callable, Union, Tuple, Optional
 from collections import defaultdict
 import functools
 from dataclasses import dataclass
@@ -73,6 +73,12 @@ class MyTensor(torch.Tensor):
 
 """
 Base class for different LayoutType, should not be instantiated directly
+used to allow users to pass around configurations for the layout tensor, e.g. inner_k_tiles
+for int4 tensor core tiled layout
+
+Note: layout is an abstraction not only for custom data representation, it is also used for how the
+layout interacts with different operators, e.g. the same data representation can have different
+behaviors when running the same operator, e.g. transpose, quantized_linear.
 """
 @dataclass(frozen=True)
 class LayoutType:
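The docstring above positions `LayoutType` as a carrier for layout configuration. A minimal sketch of a concrete layout type holding such a configuration, following the frozen dataclass pattern used by `Float8LayoutType` in this diff (the class name is illustrative, not the actual int4 tensor core tiled layout):

from dataclasses import dataclass

from torchao.dtypes.utils import LayoutType

# Illustrative layout type carrying a packing parameter for its layout tensor,
# analogous to the inner_k_tiles example mentioned in the docstring above.
@dataclass(frozen=True)
class ExampleTiledLayoutType(LayoutType):
    inner_k_tiles: int = 8

layout_type = ExampleTiledLayoutType(inner_k_tiles=4)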

torchao/quantization/quant_api.py (+1 -1)

@@ -498,7 +498,7 @@ def int8_dynamic_activation_int8_semi_sparse_weight():
 def float8_weight_only(weight_dtype: torch.dtype = torch.float8_e4m3fn):
     """
     Applies float8 weight-only symmetric per-channel quantization to linear layers.
-    
+
     Args:
         weight_dtype (torch.dtype): The target data type for weight quantization. Default is torch.float8_e4m3fn.
 