    TORCH_VERSION_AT_LEAST_2_5,
    _is_float8_type
)
+import logging
+
+logger = logging.getLogger(__name__)

from torchao.float8.float8_tensor import ScaledMMConfig
aten = torch.ops.aten
@@ -88,9 +91,28 @@ class QuantizedLinearNotImplementedError(NotImplementedError):
    pass


-_QLINEAR_DISPATCH_TABLE = {}
-def _register_quantized_linear_dispatch(dispatch_condition, impl):
-    _QLINEAR_DISPATCH_TABLE[dispatch_condition] = impl
+_AQT_QLINEAR_DISPATCH_TABLE = {}
+def register_aqt_quantized_linear_dispatch(dispatch_condition, impl):
+    """Register a dispatch for the quantized linear op with a dispatch_condition function and an impl function,
+    both of which take three arguments:
+      input_tensor: dimension is (M1, M2, ..., in_features)
+      weight_tensor: dimension is (out_features, in_features)
+      bias: dimension is (out_features,)
+    so that these can be shared by the F.linear, aten.mm, and aten.addmm dispatches
+
+    Args:
+        `dispatch_condition` (Callable[[torch.Tensor, torch.Tensor, torch.Tensor], bool]): the dispatch
+            condition for a specialized quantized linear implementation, e.g. bfloat16 activation + uint4 weight
+        `impl` (Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]): the specialized
+            quantized linear implementation
+    """
+    _AQT_QLINEAR_DISPATCH_TABLE[dispatch_condition] = impl
+
+def deregister_aqt_quantized_linear_dispatch(dispatch_condition):
+    if dispatch_condition in _AQT_QLINEAR_DISPATCH_TABLE:
+        del _AQT_QLINEAR_DISPATCH_TABLE[dispatch_condition]
+    else:
+        logger.warning(f"Attempting to remove non-existent dispatch condition {dispatch_condition}")

class AffineQuantizedTensor(TorchAOBaseTensor):
    """
@@ -189,7 +211,7 @@ def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor
    @staticmethod
    def _quantized_linear_op(input_tensor, weight_tensor, bias):
-        for dispatch_condition, impl in _QLINEAR_DISPATCH_TABLE.items():
+        for dispatch_condition, impl in _AQT_QLINEAR_DISPATCH_TABLE.items():
            if dispatch_condition(input_tensor, weight_tensor, bias):
                return impl(input_tensor, weight_tensor, bias)
        raise QuantizedLinearNotImplementedError("No specialized dispatch found for quantized linear op")
@@ -440,7 +462,7 @@ def extra_repr(self):

@dataclass(frozen=True)
class Float8LayoutType(LayoutType):
-    mm_config: Optional[ScaledMMConfig]
+    mm_config: Optional[ScaledMMConfig] = None


@register_layout_cls(PlainLayoutType)
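Giving mm_config a default of None means Float8LayoutType can now be constructed without arguments. A small usage sketch (import paths assumed, and ScaledMMConfig is assumed to be constructible with its defaults):

from torchao.float8.float8_tensor import ScaledMMConfig
from torchao.dtypes.affine_quantized_tensor import Float8LayoutType

layout = Float8LayoutType()                            # mm_config defaults to None
layout = Float8LayoutType(mm_config=ScaledMMConfig())  # or pass an explicit scaled-mm config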
@@ -598,13 +620,13 @@ def from_plain(

@register_layout_cls(Float8LayoutType)
class Float8AQTLayout(AQTLayout):
-    """
+    """
    Layout storage class for float8 layout for affine quantized tensor
    """
    float8_data: torch.Tensor
    scale: torch.Tensor
    transposed: bool
-
+
    def __new__(
        cls,
        float8_data: torch.Tensor,
@@ -639,7 +661,7 @@ def _apply_fn_to_data(self, fn):
        fn(self.float8_data)
        fn(self.scale)
        return self
-
+
    def to(self, *args, **kwargs):
        kwargs = self._get_to_kwargs(*args, **kwargs)
        return self.__class__(
@@ -976,21 +998,6 @@ def _linear_int8_act_int8_weight_semi_structured_sparse_impl(input_tensor, weigh
        y += bias
    return y

-# this is for the case when linear activation is quantized, but is not caught by the previous
-# conditions that expects a quantized activation, we just dequantize the activation so that
-# it can continue with the weight only quantization dispatches
-# NOTE: this is a fallback path that must be registered after all the implementations that expects
-# input tensor to be quantized
-def _linear_quantized_act_fallback_check(input_tensor, weight_tensor, bias):
-    return (
-        isinstance(input_tensor, AffineQuantizedTensor)
-    )
-
-def _linear_quantized_act_fallback_impl(input_tensor, weight_tensor, bias):
-    input_tensor = input_tensor.dequantize()
-    # dequantize activation and redispatch to F.linear
-    return torch.nn.functional.linear(input_tensor, weight_tensor, bias)
-
def _linear_bf16_act_uint4_weight_check(input_tensor, weight_tensor, bias):
    return (
        # input is native bfloat16 tensor
@@ -1187,19 +1194,18 @@ def _linear_fp_act_fp8_weight_impl(
    ).reshape(out_shape)


-def _register_quantized_linear_dispatches():
+def _register_aqt_quantized_linear_dispatches():
    for dispatch_condition, impl in [
        (_linear_int8_act_int8_weight_check, _linear_int8_act_int8_weight_impl),
        (_linear_int8_act_int8_weight_semi_structured_sparse_check, _linear_int8_act_int8_weight_semi_structured_sparse_impl),
        (_linear_fp_act_fp8_tensor_wise_weight_check, _linear_fp_act_fp8_weight_impl),
-        (_linear_quantized_act_fallback_check, _linear_quantized_act_fallback_impl),
        (_linear_bf16_act_uint4_weight_check, _linear_bf16_act_uint4_weight_impl),
        (_linear_fp_act_int8_weight_check, _linear_fp_act_int8_weight_impl),
        (_linear_f16_act_fpx_weight_check, _linear_f16_act_fpx_weight_impl),
    ]:
-        _register_quantized_linear_dispatch(dispatch_condition, impl)
+        register_aqt_quantized_linear_dispatch(dispatch_condition, impl)

-_register_quantized_linear_dispatches()
+_register_aqt_quantized_linear_dispatches()

@implements(torch.nn.functional.linear)
def _(func, types, args, kwargs):
@@ -1216,7 +1222,11 @@ def _(func, types, args, kwargs):
    # make the branches easier to understand in `_quantized_linear_op`
    try:
        return weight_tensor._quantized_linear_op(input_tensor, weight_tensor, bias)
-    except QuantizedLinearNotImplementedError:
+    except QuantizedLinearNotImplementedError as e:
+        # fallback path is only called when the user did not specify a specific quantized linear implementation with `layout_type.quantized_linear_impl`
+        if isinstance(weight_tensor, AffineQuantizedTensor) and hasattr(weight_tensor.layout_type, "quantized_linear_impl") and weight_tensor.layout_type.quantized_linear_impl is not None:
+            raise e
+
        if isinstance(input_tensor, AffineQuantizedTensor):
            input_tensor = input_tensor.dequantize()
        if isinstance(weight_tensor, AffineQuantizedTensor):
@@ -1239,7 +1249,11 @@ def _(func, types, args, kwargs):
    try:
        weight_tensor = weight_tensor.t()
        return weight_tensor._quantized_linear_op(input_tensor, weight_tensor, bias)
-    except QuantizedLinearNotImplementedError:
+    except QuantizedLinearNotImplementedError as e:
+        # fallback path is only called when the user did not specify a specific quantized linear implementation with `layout_type.quantized_linear_impl`
+        if isinstance(weight_tensor, AffineQuantizedTensor) and hasattr(weight_tensor.layout_type, "quantized_linear_impl") and weight_tensor.layout_type.quantized_linear_impl is not None:
+            raise e
+
        if isinstance(input_tensor, AffineQuantizedTensor):
            input_tensor = input_tensor.dequantize()
        if isinstance(weight_tensor, AffineQuantizedTensor):
@@ -1259,7 +1273,11 @@ def _(func, types, args, kwargs):
    try:
        weight_tensor = weight_tensor.t()
        return weight_tensor._quantized_linear_op(input_tensor, weight_tensor, bias)
-    except QuantizedLinearNotImplementedError:
+    except QuantizedLinearNotImplementedError as e:
+        # fallback path is only called when the user did not specify a specific quantized linear implementation with `layout_type.quantized_linear_impl`
+        if isinstance(weight_tensor, AffineQuantizedTensor) and hasattr(weight_tensor.layout_type, "quantized_linear_impl") and weight_tensor.layout_type.quantized_linear_impl is not None:
+            raise e
+
        if isinstance(input_tensor, AffineQuantizedTensor):
            input_tensor = input_tensor.dequantize()
        if isinstance(weight_tensor, AffineQuantizedTensor):
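The net effect of the three except blocks above: the dequantize-and-rerun fallback still applies by default, but a weight whose layout type declares an explicit quantized_linear_impl now surfaces QuantizedLinearNotImplementedError instead of silently running the dequantized path. A minimal sketch of a layout type that opts into the strict behavior (the class name and field value are hypothetical, and LayoutType is assumed to be importable from torchao.dtypes.utils):

from dataclasses import dataclass
from typing import Optional

from torchao.dtypes.utils import LayoutType

@dataclass(frozen=True)
class MyBackendLayoutType(LayoutType):
    # when this is not None, a failed dispatch in _quantized_linear_op re-raises
    # instead of falling back to dequantize + torch.nn.functional.linear
    quantized_linear_impl: Optional[str] = "my_backend_linear"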