diff --git a/test/dtypes/test_affine_quantized.py b/test/dtypes/test_affine_quantized.py
index d26f1d8e04..616701f1e3 100644
--- a/test/dtypes/test_affine_quantized.py
+++ b/test/dtypes/test_affine_quantized.py
@@ -218,6 +218,7 @@ def test_flatten_unflatten(self, device, dtype):
             linear = torch.nn.Linear(128, 256, dtype=dtype, device=device)
             if isinstance(apply_quant, AOBaseConfig):
                 quantize_(linear, apply_quant)
+                ql = linear
             else:
                 # TODO(#1690): delete this once config migration is done
                 ql = apply_quant(linear)
diff --git a/test/quantization/test_quant_api.py b/test/quantization/test_quant_api.py
index e0f6cb1ace..4cb0ee3579 100644
--- a/test/quantization/test_quant_api.py
+++ b/test/quantization/test_quant_api.py
@@ -33,6 +33,7 @@
     float8_dynamic_activation_float8_weight,
     float8_static_activation_float8_weight,
     float8_weight_only,
+    int4_dynamic_activation_int4_weight,
     int4_weight_only,
     int8_dynamic_activation_int4_weight,
     int8_dynamic_activation_int8_weight,
@@ -50,6 +51,7 @@
     TORCH_VERSION_AT_LEAST_2_5,
     TORCH_VERSION_AT_LEAST_2_6,
     is_sm_at_least_89,
+    is_sm_at_least_90,
     unwrap_tensor_subclass,
 )
 
@@ -798,6 +800,10 @@ def test_int4wo_cpu(self, dtype, x_dim):
             float8_weight_only(),
             float8_dynamic_activation_float8_weight(),
             float8_static_activation_float8_weight(scale=torch.tensor([1.0])),
+            int4_dynamic_activation_int4_weight(),
+            int8_dynamic_activation_int8_weight(),
+            int8_dynamic_activation_int4_weight(),
+            int8_weight_only(),
         ],
     )
     def test_workflow_e2e_numerics(self, config):
@@ -816,6 +822,11 @@ def test_workflow_e2e_numerics(self, config):
             and not is_sm_at_least_89()
         ):
             return unittest.skip("requires CUDA capability 8.9 or greater")
+        elif (
+            isinstance(config, int4_dynamic_activation_int4_weight)
+            and is_sm_at_least_90()
+        ):
+            return unittest.skip("only supported on CUDA capability 8.9, not greater")
 
         # scale has to be moved to cuda here because the parametrization init
         # code happens before gating for cuda availability
@@ -837,7 +848,7 @@ def test_workflow_e2e_numerics(self, config):
             y_q = m_q(x)
 
         sqnr = compute_error(y_ref, y_q)
-        assert sqnr >= 20, f"SQNR {sqnr} is too low"
+        assert sqnr >= 16.5, f"SQNR {sqnr} is too low"
 
 
 class TestMultiTensorFlow(TestCase):
diff --git a/torchao/quantization/__init__.py b/torchao/quantization/__init__.py
index ca9a4141fc..a1d8bda058 100644
--- a/torchao/quantization/__init__.py
+++ b/torchao/quantization/__init__.py
@@ -49,7 +49,11 @@
     Float8DynamicActivationFloat8WeightConfig,
     Float8StaticActivationFloat8WeightConfig,
     Float8WeightOnlyConfig,
+    Int4DynamicActivationInt4WeightConfig,
     Int4WeightOnlyConfig,
+    Int8DynamicActivationInt4WeightConfig,
+    Int8DynamicActivationInt8WeightConfig,
+    Int8WeightOnlyConfig,
     float8_dynamic_activation_float8_weight,
     float8_static_activation_float8_weight,
     float8_weight_only,
@@ -123,7 +127,11 @@
     "fpx_weight_only",
     "gemlite_uintx_weight_only",
     "swap_conv2d_1x1_to_linear",
+    "Int4DynamicActivationInt4WeightConfig",
+    "Int8DynamicActivationInt4WeightConfig",
+    "Int8DynamicActivationInt8WeightConfig",
     "Int4WeightOnlyConfig",
+    "Int8WeightOnlyConfig",
     "Float8WeightOnlyConfig",
     "Float8DynamicActivationFloat8WeightConfig",
     "Float8StaticActivationFloat8WeightConfig",
diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py
index 6e5e043fb0..60ee0384c9 100644
--- a/torchao/quantization/quant_api.py
+++ b/torchao/quantization/quant_api.py
@@ -43,6 +43,7 @@
     to_affine_quantized_intx,
     to_marlinqqq_quantized_intx,
 )
+from torchao.dtypes.utils import Layout
 from torchao.float8.float8_linear import Float8Linear
 from torchao.float8.inference import Float8MMConfig
 from torchao.quantization.linear_activation_weight_observed_tensor import (
@@ -590,18 +591,45 @@ def _int8_symm_per_token_quant(x: torch.Tensor) -> torch.Tensor:
     )
 
 
-def apply_int8_dynamic_activation_int4_weight_quant(
-    weight,
-    group_size=32,
-    layout=PlainLayout(),
-    mapping_type=MappingType.SYMMETRIC,
-    act_mapping_type=MappingType.ASYMMETRIC,
+@dataclass
+class Int8DynamicActivationInt4WeightConfig(AOBaseConfig):
+    """Configuration for applying int8 dynamic per token asymmetric activation quantization and int4 per group weight symmetric quantization to linear
+    This is used to produce a model for executorch backend, but currently executorch did not
+    support lowering for the quantized model from this flow yet
+
+    Args:
+        `group_size`: parameter for quantization, controls the granularity of quantization, smaller
+         size is more fine grained
+        `layout`: layout type for quantized weight tensor, only supports `MarlinQQQLayout()` and `CutlassInt4PackedLayout()` for now
+        `mapping_type`: quantization type for weight, controls the weight quantization is symmetric or asymmetric
+        `act_mapping_type`: quantization type for activation, controls the activation quantization is symmetric or asymmetric
+    """
+
+    group_size: int = 32
+    layout: Layout = PlainLayout()
+    mapping_type: MappingType = MappingType.SYMMETRIC
+    act_mapping_type: MappingType = MappingType.ASYMMETRIC
+
+
+# for BC
+int8_dynamic_activation_int4_weight = Int8DynamicActivationInt4WeightConfig
+
+
+@register_quantize_module_handler(Int8DynamicActivationInt4WeightConfig)
+def _int8_dynamic_activation_int4_weight_transform(
+    module: torch.nn.Module, config: Int8DynamicActivationInt4WeightConfig
 ):
-    """This is defined here instead of local function to support serialization"""
+    group_size = config.group_size
+    layout = config.layout
+    mapping_type = config.mapping_type
+    act_mapping_type = config.act_mapping_type
+
+    weight = module.weight
+
     if group_size is None or group_size == -1:
         group_size = weight.shape[-1]
     if weight.shape[-1] % group_size != 0:
-        return weight
+        return module
 
     # weight settings
     block_size = (1, group_size)
@@ -639,41 +667,39 @@ def apply_int8_dynamic_activation_int4_weight_quant(
         _layout=layout,
     )
     weight = to_linear_activation_quantized(weight, input_quant_func)
-    return weight
+    module.weight = torch.nn.Parameter(weight, requires_grad=False)
+    module.extra_repr = types.MethodType(_linear_extra_repr, module)
+    return module
 
 
-def int8_dynamic_activation_int4_weight(
-    group_size=32,
-    layout=PlainLayout(),
-    mapping_type=MappingType.SYMMETRIC,
-    act_mapping_type=MappingType.ASYMMETRIC,
-):
-    """Applies int8 dynamic per token asymmetric activation quantization and int4 per group weight symmetric quantization to linear
-    This is used to produce a model for executorch backend, but currently executorch did not
-    support lowering for the quantized model from this flow yet
+@dataclass
+class Int4DynamicActivationInt4WeightConfig(AOBaseConfig):
+    """Applies int4 dynamic per token symmetric activation quantization and int4 per row weight symmetric quantization to linear
 
     Args:
-        `group_size`: parameter for quantization, controls the granularity of quantization, smaller
-         size is more fine grained
         `layout`: layout type for quantized weight tensor, only supports `MarlinQQQLayout()` and `CutlassInt4PackedLayout()` for now
         `mapping_type`: quantization type for weight, controls the weight quantization is symmetric or asymmetric
         `act_mapping_type`: quantization type for activation, controls the activation quantization is symmetric or asymmetric
     """
-    return _get_linear_subclass_inserter(
-        apply_int8_dynamic_activation_int4_weight_quant,
-        group_size=group_size,
-        layout=layout,
-        mapping_type=mapping_type,
-        act_mapping_type=act_mapping_type,
-    )
+    layout: Layout = CutlassInt4PackedLayout()
+    mapping_type: MappingType = MappingType.SYMMETRIC
+    act_mapping_type: MappingType = MappingType.SYMMETRIC
+
+
+# for bc
+int4_dynamic_activation_int4_weight = Int4DynamicActivationInt4WeightConfig
+
+
+@register_quantize_module_handler(Int4DynamicActivationInt4WeightConfig)
+def _int4_dynamic_activation_int4_weight_transform(
+    module: torch.nn.Module, config: Int4DynamicActivationInt4WeightConfig
+) -> torch.nn.Module:
+    weight = module.weight
+    layout = config.layout
+    mapping_type = config.mapping_type
+    act_mapping_type = config.act_mapping_type
 
-def apply_int4_dynamic_activation_int4_weight_quant(
-    weight: torch.Tensor,
-    layout=CutlassInt4PackedLayout(),
-    mapping_type=MappingType.SYMMETRIC,
-    act_mapping_type=MappingType.SYMMETRIC,
-):
     if not isinstance(layout, CutlassInt4PackedLayout):
         raise NotImplementedError(
             f"Only CutlassInt4PackedLayout layout is supported. Received {layout}."
@@ -698,27 +724,9 @@ def apply_int4_dynamic_activation_int4_weight_quant(
         weight,
         _int4_symm_per_token_quant_cutlass,
     )
-    return weight
-
-
-def int4_dynamic_activation_int4_weight(
-    layout=CutlassInt4PackedLayout(),
-    mapping_type=MappingType.SYMMETRIC,
-    act_mapping_type=MappingType.SYMMETRIC,
-):
-    """Applies int4 dynamic per token symmetric activation quantization and int4 per row weight symmetric quantization to linear
-
-    Args:
-        `layout`: layout type for quantized weight tensor, only supports `MarlinQQQLayout()` and `CutlassInt4PackedLayout()` for now
-        `mapping_type`: quantization type for weight, controls the weight quantization is symmetric or asymmetric
-        `act_mapping_type`: quantization type for activation, controls the activation quantization is symmetric or asymmetric
-    """
-    return _get_linear_subclass_inserter(
-        apply_int4_dynamic_activation_int4_weight_quant,
-        layout=layout,
-        mapping_type=mapping_type,
-        act_mapping_type=act_mapping_type,
-    )
+    module.weight = torch.nn.Parameter(weight, requires_grad=False)
+    module.extra_repr = types.MethodType(_linear_extra_repr, module)
+    return module
 
 
 def gemlite_uintx_weight_only(
@@ -857,29 +865,42 @@ def _int4_weight_only_transform(
     return module
 
 
-def int8_weight_only(group_size=None):
+@dataclass
+class Int8WeightOnlyConfig(AOBaseConfig):
     """
-    Applies int8 weight-only symmetric per-channel quantization to linear layers.
+    Configuration for applying int8 weight-only symmetric per-channel quantization to linear layers.
     """
 
-    def apply_int8wo_quant(weight, group_size=None):
-        mapping_type = MappingType.SYMMETRIC
-        target_dtype = torch.int8
-        eps = torch.finfo(torch.float32).eps
-        zero_point_dtype = torch.int64
-        if group_size is None:
-            group_size = weight.shape[1]
-        block_size = (1, group_size)
-        return to_affine_quantized_intx(
-            weight,
-            mapping_type,
-            block_size,
-            target_dtype,
-            eps=eps,
-            zero_point_dtype=zero_point_dtype,
-        )
+    group_size: Optional[int] = None
+
+
+# for BC
+int8_weight_only = Int8WeightOnlyConfig
+
+
+@register_quantize_module_handler(Int8WeightOnlyConfig)
+def _int8_weight_only_transform(module: torch.nn.Module, config: Int8WeightOnlyConfig):
+    group_size = config.group_size
+    weight = module.weight
 
-    return _get_linear_subclass_inserter(apply_int8wo_quant, group_size=group_size)
+    mapping_type = MappingType.SYMMETRIC
+    target_dtype = torch.int8
+    eps = torch.finfo(torch.float32).eps
+    zero_point_dtype = torch.int64
+    if group_size is None:
+        group_size = weight.shape[1]
+    block_size = (1, group_size)
+    new_weight = to_affine_quantized_intx(
+        weight,
+        mapping_type,
+        block_size,
+        target_dtype,
+        eps=eps,
+        zero_point_dtype=zero_point_dtype,
+    )
+    module.weight = torch.nn.Parameter(new_weight, requires_grad=False)
+    module.extra_repr = types.MethodType(_linear_extra_repr, module)
+    return module
 
 
 def _int8_symm_per_token_reduced_range_quant(x: torch.Tensor) -> torch.Tensor:
@@ -958,63 +979,76 @@ def _int4_symm_per_token_quant_cutlass(x: torch.Tensor) -> torch.Tensor:
     )
 
 
-def int8_dynamic_activation_int8_weight(
-    layout=PlainLayout(),
-    act_mapping_type=MappingType.SYMMETRIC,
-    weight_only_decode=False,
-):
+@dataclass
+class Int8DynamicActivationInt8WeightConfig(AOBaseConfig):
     """
-    Applies int8 dynamic symmetric per-token activation and int8 per-channel weight
+    Configuration for applying int8 dynamic symmetric per-token activation and int8 per-channel weight
     quantization to linear layers
     """
 
-    def apply_int8_dynamic_activation_int8_weight_quant(weight):
-        in_features = weight.shape[1]
-        # int8 dynamic quantization only has benefit when in_feature > 16
-        if in_features <= 16:
-            logger.info(
-                f"Skipping applying int8_dynamic_activation_int8_weight to weight of shape {weight.shape}"
-                f" because `in_feature` is <= 16: {in_features}"
-            )
-            return weight
+    layout: Optional[Layout] = PlainLayout()
+    act_mapping_type: Optional[MappingType] = MappingType.SYMMETRIC
+    weight_only_decode: bool = False
 
-        # weight settings
-        mapping_type = MappingType.SYMMETRIC
-        weight_zero_point_domain = ZeroPointDomain.NONE
 
-        def get_weight_block_size(x):
-            return (1, x.shape[1])
+# for BC
+int8_dynamic_activation_int8_weight = Int8DynamicActivationInt8WeightConfig
 
-        target_dtype = torch.int8
-        eps = torch.finfo(torch.float32).eps
-        zero_point_dtype = torch.int64
 
-        if weight_only_decode:
-            input_quant_func = _int8_symm_per_token_reduced_range_quant_noop_decode
-        else:
-            # input settings
-            if act_mapping_type == MappingType.SYMMETRIC:
-                input_quant_func = _int8_symm_per_token_reduced_range_quant
-            else:
-                input_quant_func = _int8_asymm_per_token_quant
+@register_quantize_module_handler(Int8DynamicActivationInt8WeightConfig)
+def _int8_dynamic_activation_int8_weight_transform(
+    module: torch.nn.Module, config: Int8DynamicActivationInt8WeightConfig
+) -> torch.nn.Module:
+    layout = config.layout
+    act_mapping_type = config.act_mapping_type
+    weight_only_decode = config.weight_only_decode
 
-        block_size = get_weight_block_size(weight)
-        weight = to_affine_quantized_intx(
-            weight,
-            mapping_type,
-            block_size,
-            target_dtype,
-            eps=eps,
-            zero_point_dtype=zero_point_dtype,
-            _layout=layout,
-            zero_point_domain=weight_zero_point_domain,
+    weight = module.weight
+
+    in_features = weight.shape[1]
+    # int8 dynamic quantization only has benefit when in_feature > 16
+    if in_features <= 16:
+        logger.info(
+            f"Skipping applying int8_dynamic_activation_int8_weight to weight of shape {weight.shape}"
+            f" because `in_feature` is <= 16: {in_features}"
         )
-        weight = to_linear_activation_quantized(weight, input_quant_func)
-        return weight
+        return module
+
+    # weight settings
+    mapping_type = MappingType.SYMMETRIC
+    weight_zero_point_domain = ZeroPointDomain.NONE
 
-    return _get_linear_subclass_inserter(
-        apply_int8_dynamic_activation_int8_weight_quant
+    def get_weight_block_size(x):
+        return (1, x.shape[1])
+
+    target_dtype = torch.int8
+    eps = torch.finfo(torch.float32).eps
+    zero_point_dtype = torch.int64
+
+    if weight_only_decode:
+        input_quant_func = _int8_symm_per_token_reduced_range_quant_noop_decode
+    else:
+        # input settings
+        if act_mapping_type == MappingType.SYMMETRIC:
+            input_quant_func = _int8_symm_per_token_reduced_range_quant
+        else:
+            input_quant_func = _int8_asymm_per_token_quant
+
+    block_size = get_weight_block_size(weight)
+    weight = to_affine_quantized_intx(
+        weight,
+        mapping_type,
+        block_size,
+        target_dtype,
+        eps=eps,
+        zero_point_dtype=zero_point_dtype,
+        _layout=layout,
+        zero_point_domain=weight_zero_point_domain,
     )
+    weight = to_linear_activation_quantized(weight, input_quant_func)
+    module.weight = torch.nn.Parameter(weight, requires_grad=False)
+    module.extra_repr = types.MethodType(_linear_extra_repr, module)
+    return module
 
 
 def int8_dynamic_activation_int8_semi_sparse_weight():
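
Usage sketch for the config-based API introduced above (illustrative only; the toy model and layer sizes are placeholders, not taken from the diff). As the updated tests show, a config instance is passed directly to quantize_, and the old lowercase names remain as backward-compatibility aliases for the config classes.

    import torch

    from torchao.quantization import Int8WeightOnlyConfig, quantize_

    # Placeholder model; any module containing nn.Linear layers is handled the same way.
    model = torch.nn.Sequential(torch.nn.Linear(128, 256))

    # New style: pass a config instance instead of a callable.
    quantize_(model, Int8WeightOnlyConfig())

    # Backward compatibility: int8_weight_only is now an alias for Int8WeightOnlyConfig,
    # so existing code calling int8_weight_only() constructs the same config object.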