diff --git a/test/dtypes/test_affine_quantized.py b/test/dtypes/test_affine_quantized.py
index 52b25dab82..8cc0b961a8 100644
--- a/test/dtypes/test_affine_quantized.py
+++ b/test/dtypes/test_affine_quantized.py
@@ -8,6 +8,7 @@
     run_tests,
 )
 
+from torchao.core.config import AOBaseConfig
 from torchao.dtypes import CutlassInt4PackedLayout, Int4CPULayout, SemiSparseLayout
 from torchao.quantization import (
     float8_weight_only,
@@ -16,6 +17,7 @@
     int8_dynamic_activation_int4_weight,
     int8_dynamic_activation_int8_weight,
     int8_weight_only,
+    quantize_,
 )
 from torchao.quantization.quant_primitives import MappingType, ZeroPointDomain
 from torchao.utils import (
@@ -82,7 +84,8 @@ def test_tensor_core_layout_transpose(self):
         t = linear.weight
         shape = t.shape
         apply_int4_weight_only_quant = int4_weight_only(group_size=32)
-        ql = apply_int4_weight_only_quant(linear)
+        quantize_(linear, apply_int4_weight_only_quant)
+        ql = linear
         aqt = ql.weight
         aqt_shape = aqt.shape
         self.assertEqual(aqt_shape, shape)
@@ -102,7 +105,11 @@ def test_tensor_core_layout_transpose(self):
     )
     def test_weights_only(self, apply_quant):
         linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
-        ql = apply_quant(linear)
+        if isinstance(apply_quant, AOBaseConfig):
+            quantize_(linear, apply_quant)
+            ql = linear
+        else:
+            ql = apply_quant(linear)
         with tempfile.NamedTemporaryFile() as f:
             torch.save(ql.state_dict(), f)
             f.seek(0)
@@ -180,8 +187,13 @@ def apply_uint6_weight_only_quant(linear):
     )
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
     def test_print_quantized_module(self, apply_quant):
+        print(apply_quant)
         linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
-        ql = apply_quant(linear)
+        if isinstance(apply_quant, AOBaseConfig):
+            quantize_(linear, apply_quant)
+            ql = linear
+        else:
+            ql = apply_quant(linear)
         assert "AffineQuantizedTensor" in str(ql)
 
 
@@ -195,7 +207,11 @@ def test_flatten_unflatten(self, device, dtype):
         apply_quant_list = get_quantization_functions(False, True, device)
         for apply_quant in apply_quant_list:
             linear = torch.nn.Linear(128, 256, dtype=dtype, device=device)
-            ql = apply_quant(linear)
+            if isinstance(apply_quant, AOBaseConfig):
+                quantize_(linear, apply_quant)
+                ql = linear
+            else:
+                ql = apply_quant(linear)
             lp_tensor = ql.weight
             tensor_data_name_dict, tensor_attributes = lp_tensor.__tensor_flatten__()
             tensor_data_dict = {
diff --git a/test/hqq/test_hqq_affine.py b/test/hqq/test_hqq_affine.py
index 381886d594..096c9d26ba 100644
--- a/test/hqq/test_hqq_affine.py
+++ b/test/hqq/test_hqq_affine.py
@@ -6,6 +6,7 @@
     MappingType,
     ZeroPointDomain,
     int4_weight_only,
+    quantize_,
     uintx_weight_only,
 )
 from torchao.utils import (
@@ -51,9 +52,9 @@ def _eval_hqq(dtype):
     )
     dummy_linear.weight.data = W
     if dtype == torch.uint4:
-        q_tensor_hqq = int4_weight_only(group_size=max(block_size), use_hqq=True)(
-            dummy_linear
-        ).weight
+        config = int4_weight_only(group_size=max(block_size), use_hqq=True)
+        quantize_(dummy_linear, config)
+        q_tensor_hqq = dummy_linear.weight
     else:
         q_tensor_hqq = uintx_weight_only(
             dtype, group_size=max(block_size), use_hqq=True
diff --git a/test/quantization/test_qat.py b/test/quantization/test_qat.py
index 8a78b8b387..82324394a8 100644
--- a/test/quantization/test_qat.py
+++ b/test/quantization/test_qat.py
@@ -1185,7 +1185,7 @@ def test_qat_prototype_bc(self):
     @unittest.skipIf(
         not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower"
     )
-    def test_quantize_api(self):
+    def test_quantize_api_standalone(self):
         """
         Test that the following:
diff --git a/test/quantization/test_quant_api.py b/test/quantization/test_quant_api.py
index caba1cf31f..acd9b50c5a 100644
--- a/test/quantization/test_quant_api.py
+++ b/test/quantization/test_quant_api.py
@@ -40,6 +40,7 @@
     Int4WeightOnlyQuantizedLinearWeight,
     Int8WeightOnlyQuantizedLinearWeight,
 )
+from torchao.quantization.utils import compute_error
 from torchao.utils import (
     TORCH_VERSION_AT_LEAST_2_3,
     TORCH_VERSION_AT_LEAST_2_4,
@@ -783,6 +784,30 @@ def test_int4wo_cpu(self, dtype, x_dim):
         assert "_weight_int4pack_mm_for_cpu" in code[0]
         assert "aten.mm.default" not in code[0]
 
+    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+")
+    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    def test_int4_weight_only_numerics(self):
+        """
+        Simple test of e2e int4_weight_only workflow, comparing numerics
+        to a bfloat16 baseline.
+        """
+        # set up inputs
+        x = torch.randn(128, 128, device="cuda", dtype=torch.bfloat16)
+        # TODO(future): model in float32 leads to error: https://gist.github.com/vkuzo/63b3bcd7818393021a6e3fb4ccf3c469
+        # is that expected?
+        m_ref = torch.nn.Sequential(torch.nn.Linear(128, 128)).cuda().bfloat16()
+        m_int4_wo = copy.deepcopy(m_ref)
+
+        # quantize
+        quantize_(m_int4_wo, int4_weight_only())
+
+        with torch.no_grad():
+            y_ref = m_ref(x)
+            y_int4_wo = m_int4_wo(x)
+
+        sqnr = compute_error(y_ref, y_int4_wo)
+        assert sqnr >= 20, f"SQNR {sqnr} is too low"
+
 
 class TestMultiTensorFlow(TestCase):
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+")
diff --git a/torchao/core/__init__.py b/torchao/core/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/torchao/core/config.py b/torchao/core/config.py
new file mode 100644
index 0000000000..a91209b8ed
--- /dev/null
+++ b/torchao/core/config.py
@@ -0,0 +1,10 @@
+import abc
+
+
+class AOBaseConfig(abc.ABC):
+    """
+    If a workflow config inherits from this then `quantize_` knows
+    how to apply it to a model.
+    """
+
+    pass
diff --git a/torchao/quantization/__init__.py b/torchao/quantization/__init__.py
index aa4a51d497..b68ab8a179 100644
--- a/torchao/quantization/__init__.py
+++ b/torchao/quantization/__init__.py
@@ -4,6 +4,7 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 
+
 from torchao.kernel import (
     int_scaled_matmul,
     safe_int_mm,
@@ -85,6 +86,7 @@
     swap_linear_with_smooth_fq_linear,
 )
 from .subclass import *  # noqa: F403
+from .transform_module import register_quantize_module_handler
 from .unified import Quantizer, TwoStepQuantizer
 from .utils import (
     compute_error,
@@ -144,6 +146,8 @@
     # operators/kernels
     "safe_int_mm",
     "int_scaled_matmul",
+    # registration of module transforms for quantize_
+    "register_quantize_module_handler",
     # dataclasses and types
     "MappingType",
     "ZeroPointDomain",
diff --git a/torchao/quantization/qat/api.py b/torchao/quantization/qat/api.py
index 925a0eed3c..ab1b270c50 100644
--- a/torchao/quantization/qat/api.py
+++ b/torchao/quantization/qat/api.py
@@ -5,10 +5,11 @@
 # LICENSE file in the root directory of this source tree.
 
 from dataclasses import dataclass
-from typing import Any, Callable, List, Optional, Union
+from typing import Any, List, Optional, Union
 
 import torch
 
+from torchao.core.config import AOBaseConfig
 from torchao.quantization.granularity import (
     Granularity,
     PerAxis,
@@ -22,6 +23,9 @@
     TorchAODType,
     ZeroPointDomain,
 )
+from torchao.quantization.transform_module import (
+    register_quantize_module_handler,
+)
 from torchao.quantization.unified import TwoStepQuantizer
 
 
@@ -241,12 +245,26 @@ def __setattr__(self, name: str, value: Any):
         super().__setattr__(name, value)
 
 
-def intx_quantization_aware_training(
-    activation_config: Optional[FakeQuantizeConfig] = None,
-    weight_config: Optional[FakeQuantizeConfig] = None,
-) -> Callable:
+@dataclass
+class IntXQuantizationAwareTrainingConfig(AOBaseConfig):
+    activation_config: Optional[FakeQuantizeConfig] = None
+    weight_config: Optional[FakeQuantizeConfig] = None
+
+
+# for BC
+intx_quantization_aware_training = IntXQuantizationAwareTrainingConfig
+
+
+@register_quantize_module_handler(IntXQuantizationAwareTrainingConfig)
+def _intx_quantization_aware_training_transform(
+    module: torch.nn.Module,
+    config: IntXQuantizationAwareTrainingConfig,
+) -> torch.nn.Module:
     """
-    Return a function that applies fake quantization to a `torch.nn.Module`.
+    THIS IS NOT A PUBLIC API - any usage of this outside of torchao
+    can break at any time.
+
+    Apply fake quantization to a `torch.nn.Module`,
     to be used with :func:`~torchao.quantization.quant_api.quantize_`.
 
     Example usage::
@@ -269,37 +287,32 @@ def intx_quantization_aware_training(
         `torch.nn.Embedding` with an activation config, then we will raise
         ValueError as these are not supported.
     """
-
-    def _insert_fake_quantize(mod: torch.nn.Module):
-        """
-        Swap the given module with its corresponding fake quantized version.
-        """
-        from .embedding import FakeQuantizedEmbedding
-        from .linear import FakeQuantizedLinear
-
-        if isinstance(mod, torch.nn.Linear):
-            return FakeQuantizedLinear.from_linear(
-                mod,
-                activation_config,
-                weight_config,
-            )
-        elif isinstance(mod, torch.nn.Embedding):
-            if activation_config is not None:
-                raise ValueError(
-                    "Activation fake quantization is not supported for embedding"
-                )
-            return FakeQuantizedEmbedding.from_embedding(mod, weight_config)
-        else:
+    from .embedding import FakeQuantizedEmbedding
+    from .linear import FakeQuantizedLinear
+
+    mod = module
+    activation_config = config.activation_config
+    weight_config = config.weight_config
+
+    if isinstance(mod, torch.nn.Linear):
+        return FakeQuantizedLinear.from_linear(
+            mod,
+            activation_config,
+            weight_config,
+        )
+    elif isinstance(mod, torch.nn.Embedding):
+        if activation_config is not None:
             raise ValueError(
-                "Module of type '%s' does not have QAT support" % type(mod)
+                "Activation fake quantization is not supported for embedding"
             )
+        return FakeQuantizedEmbedding.from_embedding(mod, weight_config)
+    else:
+        raise ValueError("Module of type '%s' does not have QAT support" % type(mod))
 
-    return _insert_fake_quantize
-
 
-def from_intx_quantization_aware_training() -> Callable:
+class FromIntXQuantizationAwareTrainingConfig(AOBaseConfig):
     """
-    Return a function that converts a model with fake quantized modules,
+    Object that knows how to convert a model with fake quantized modules,
     such as :func:`~torchao.quantization.qat.linear.FakeQuantizedLinear`
     and :func:`~torchao.quantization.qat.linear.FakeQuantizedEmbedding`,
     back to model with the original, corresponding modules without
@@ -315,22 +328,31 @@ def from_intx_quantization_aware_training() -> Callable:
     )
     """
 
-    def _remove_fake_quantize(mod: torch.nn.Module):
-        """
-        If the given module is a fake quantized module, return the original
-        corresponding version of the module without fake quantization.
-        """
-        from .embedding import FakeQuantizedEmbedding
-        from .linear import FakeQuantizedLinear
+    pass
+
+
+# for BC
+from_intx_quantization_aware_training = FromIntXQuantizationAwareTrainingConfig
 
-        if isinstance(mod, FakeQuantizedLinear):
-            return mod.to_linear()
-        elif isinstance(mod, FakeQuantizedEmbedding):
-            return mod.to_embedding()
-        else:
-            return mod
 
-    return _remove_fake_quantize
+@register_quantize_module_handler(FromIntXQuantizationAwareTrainingConfig)
+def _from_intx_quantization_aware_training_transform(
+    mod: torch.nn.Module,
+    config: FromIntXQuantizationAwareTrainingConfig,
+) -> torch.nn.Module:
+    """
+    If the given module is a fake quantized module, return the original
+    corresponding version of the module without fake quantization.
+    """
+    from .embedding import FakeQuantizedEmbedding
+    from .linear import FakeQuantizedLinear
+
+    if isinstance(mod, FakeQuantizedLinear):
+        return mod.to_linear()
+    elif isinstance(mod, FakeQuantizedEmbedding):
+        return mod.to_embedding()
+    else:
+        return mod
 
 
 class ComposableQATQuantizer(TwoStepQuantizer):
diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py
index 9b7999449f..a6e8ee8e0b 100644
--- a/torchao/quantization/quant_api.py
+++ b/torchao/quantization/quant_api.py
@@ -18,13 +18,15 @@
 import logging
 import types
 import warnings
-from typing import Callable, Optional, Tuple, Union
+from dataclasses import dataclass
+from typing import Any, Callable, Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
 import torch.nn.utils.parametrize as parametrize
 
 import torchao
+from torchao.core.config import AOBaseConfig
 from torchao.dtypes import (
     AffineQuantizedTensor,
     CutlassInt4PackedLayout,
@@ -47,6 +49,10 @@
     LinearActivationWeightObservedTensor,
 )
 from torchao.quantization.observer import AffineQuantizedObserverBase, get_block_size
+from torchao.quantization.transform_module import (
+    _QUANTIZE_CONFIG_HANDLER,
+    register_quantize_module_handler,
+)
 from torchao.quantization.weight_tensor_linear_activation_quantization import (
     to_weight_tensor_with_linear_activation_quantization_metadata,
 )
@@ -117,7 +123,6 @@
     "Int8DynActInt4WeightGPTQQuantizer",
 ]
 
-# update according to the support matrix
 LAYOUT_TO_ZERO_POINT_DOMAIN = {
     TensorCoreTiledLayout: [ZeroPointDomain.FLOAT],
     MarlinSparseLayout: [ZeroPointDomain.INT],
@@ -228,6 +233,7 @@ def _replace_with_custom_fn_if_matches_filter(
     filter_fn,
     cur_fqn="",
     device=None,
+    extra_args: Optional[Tuple[Any, ...]] = (),
 ) -> None:
     """
     Recursively replaces each child module in `model` with the result of `replacement_fn(child)`
@@ -239,6 +245,7 @@ def _replace_with_custom_fn_if_matches_filter(
         filter_fn (Callable[[torch.nn.Module], bool]): The filter function to determine which modules to replace.
         cur_fqn (str, optional): The current fully qualified name of the module being processed. Defaults to "".
         device (device, optional): Device to move the model to before applying `filter_fn`. Defaults to None.
+        extra_args (Tuple[Any, ...], optional): optional extra args to pass to `replacement_fn`.
 
     Returns:
         None
@@ -252,12 +259,18 @@
     if filter_fn(model, cur_fqn[:-1]):
         if device is not None:
             model.to(device=device)  # move to device before quantization
-        model = replacement_fn(model)
+        model = replacement_fn(model, *extra_args)
         return model
     else:
-        for name, child in model.named_children():
+        named_children_list = list(model.named_children())
+        for name, child in named_children_list:
             new_child = _replace_with_custom_fn_if_matches_filter(
-                child, replacement_fn, filter_fn, f"{cur_fqn}{name}.", device
+                child,
+                replacement_fn,
+                filter_fn,
+                f"{cur_fqn}{name}.",
+                device,
+                extra_args,
             )
             if new_child is not child:
                 setattr(model, name, new_child)
@@ -472,17 +485,17 @@ def insert_subclass(lin):
 
 
 def quantize_(
     model: torch.nn.Module,
-    apply_tensor_subclass: Callable[[torch.nn.Module], torch.nn.Module],
+    config: Union[AOBaseConfig, Callable[[torch.nn.Module], torch.nn.Module]],
     filter_fn: Optional[Callable[[torch.nn.Module, str], bool]] = None,
     set_inductor_config: bool = True,
     device: Optional[torch.types.Device] = None,
 ):
-    """Convert the weight of linear modules in the model with `apply_tensor_subclass`, model is modified inplace
+    """Convert the weight of linear modules in the model with `config`, model is modified inplace
 
     Args:
         model (torch.nn.Module): input model
-        apply_tensor_subclass (Callable[[torch.nn.Module], torch.nn.Module]): function that applies tensor subclass conversion to the weight of a module and return the module (e.g. convert the weight tensor of linear to affine quantized tensor)
-        filter_fn (Optional[Callable[[torch.nn.Module, str], bool]]): function that takes a nn.Module instance and fully qualified name of the module, returns True if we want to run `apply_tensor_subclass` on
+        config (Union[AOBaseConfig, Callable[[torch.nn.Module], torch.nn.Module]]): either (1) a workflow configuration object or (2) a function that applies tensor subclass conversion to the weight of a module and returns the module (e.g. convert the weight tensor of linear to affine quantized tensor). Note: (2) will be deleted in a future release.
+        filter_fn (Optional[Callable[[torch.nn.Module, str], bool]]): function that takes a nn.Module instance and fully qualified name of the module, returns True if we want to run `config` on
        the weight of the module
         set_inductor_config (bool, optional): Whether to automatically use recommended inductor config settings (defaults to True)
         device (device, optional): Device to move module to before applying `filter_fn`. This can be set to `"cuda"` to speed up quantization. The final model will be on the specified `device`.
@@ -494,7 +507,7 @@ def quantize_(
         import torch.nn as nn
         from torchao import quantize_
 
-        # 1. quantize with some predefined `apply_tensor_subclass` method that corresponds to
+        # quantize with some predefined `config` that corresponds to
         # optimized execution paths or kernels (e.g. int4 tinygemm kernel)
         # also customizable with arguments
         # currently options are
@@ -507,39 +520,36 @@ def quantize_(
         m = nn.Sequential(nn.Linear(32, 1024), nn.Linear(1024, 32))
         quantize_(m, int4_weight_only(group_size=32))
 
-        # 2. write your own new apply_tensor_subclass
-        # You can also add your own apply_tensor_subclass by manually calling tensor subclass constructor
-        # on weight
-
-        from torchao.dtypes import to_affine_quantized_intx
-
-        # weight only uint4 asymmetric groupwise quantization
-        groupsize = 32
-        apply_weight_quant = lambda x: to_affine_quantized_intx(
-            x, "asymmetric", (1, groupsize), torch.int32, 0, 15, 1e-6,
-            zero_point_dtype=torch.bfloat16, preserve_zero=False, zero_point_domain="float")
-
-        def apply_weight_quant_to_linear(linear):
-            linear.weight = torch.nn.Parameter(apply_weight_quant(linear.weight), requires_grad=False)
-            return linear
-
-        # apply to modules under block0 submodule
-        def filter_fn(module: nn.Module, fqn: str) -> bool:
-            return isinstance(module, nn.Linear)
-
-        m = nn.Sequential(nn.Linear(32, 1024), nn.Linear(1024, 32))
-        quantize_(m, apply_weight_quant_to_linear, filter_fn)
-
     """
     if set_inductor_config:
         torchao.quantization.utils.recommended_inductor_config_setter()
 
-    _replace_with_custom_fn_if_matches_filter(
-        model,
-        apply_tensor_subclass,
-        _is_linear if filter_fn is None else filter_fn,
-        device=device,
-    )
+    if isinstance(config, AOBaseConfig):
+        handler = _QUANTIZE_CONFIG_HANDLER[type(config)]
+        # for each linear in the model, apply the transform if filtering passes
+        _replace_with_custom_fn_if_matches_filter(
+            model,
+            handler,
+            _is_linear if filter_fn is None else filter_fn,
+            device=device,
+            extra_args=(config,),
+        )
+
+    else:
+        # old behavior, keep to avoid breaking BC
+        warnings.warn(
+            """Passing a generic Callable to `quantize_` is no longer recommended and will be deprecated at a later release. Please see https://github.com/pytorch/ao/pull/1595 for instructions on how to pass in workflow configuration instead."""
+        )
+
+        # make the variable name make sense
+        apply_tensor_subclass = config
+
+        _replace_with_custom_fn_if_matches_filter(
+            model,
+            apply_tensor_subclass,
+            _is_linear if filter_fn is None else filter_fn,
+            device=device,
+        )
 
 
 def _int8_asymm_per_token_quant(x: torch.Tensor) -> torch.Tensor:
@@ -741,14 +751,10 @@ def gemlite_uintx_weight_only(
     return _get_linear_subclass_inserter(apply_fn)
 
 
-def int4_weight_only(
-    group_size=128,
-    layout=TensorCoreTiledLayout(inner_k_tiles=8),
-    use_hqq=False,
-    zero_point_domain=ZeroPointDomain.NONE,
-):
+@dataclass
+class Int4WeightOnlyConfig(AOBaseConfig):
     """
-    Applies uint4 weight-only asymmetric per-group quantization to linear layers, using
+    Configuration for applying uint4 weight-only asymmetric per-group quantization to linear layers, using
     "tensor_core_tiled" layout for speedup with tinygemm kernel
 
     Note:
@@ -765,64 +771,90 @@
         size is more fine grained, choices are [256, 128, 64, 32]
         `layout`: layout type for quantized tensor, default is `TensorCoreTiledLayout(inner_k_tiles=8)`
         `use_hqq`: whether to use hqq or default quantization mode, default is False
-        `zero_point_domain`: data type of zeros points, choices are [None(then the value is determined by the layout), ZeroPointDomain.FLOAT, ZeroPointDomain.INT, ZeroPointDomain.NONE]
+        `zero_point_domain`: data type of zeros points, choices are [ZeroPointDomain.FLOAT, ZeroPointDomain.INT, ZeroPointDomain.NONE]
     """
 
-    def apply_int4_weight_only_quant(weight):
-        if weight.shape[-1] % group_size != 0:
-            logger.info(
-                f"Skipping quantizing weight with int4 weight only quantization because the shape of weight {weight.shape} is not compatible with group_size {group_size}"
-            )
-            return weight
+    group_size: int = 128
+    layout: Optional[TensorCoreTiledLayout] = TensorCoreTiledLayout(inner_k_tiles=8)
+    use_hqq: bool = False
+    zero_point_domain: Optional[ZeroPointDomain] = ZeroPointDomain.NONE
 
-        mapping_type = MappingType.ASYMMETRIC
-        block_size = (1, group_size)
-        target_dtype = torch.int32
-        quant_min = 0
-        quant_max = 15
-        eps = 1e-6
-        preserve_zero = LAYOUT_TO_PRESERVE_ZEROS[type(layout)]
-        zero_point_dtype = (
-            weight.dtype if isinstance(layout, Int4CPULayout) else torch.bfloat16
+
+# for BC
+# TODO maybe change other callsites
+int4_weight_only = Int4WeightOnlyConfig
+
+
+@register_quantize_module_handler(Int4WeightOnlyConfig)
+def _int4_weight_only_transform(
+    module: torch.nn.Module, config: Int4WeightOnlyConfig
+) -> torch.nn.Module:
+    # TODO(future PR): perhaps move this logic to a different file, to keep the API
+    # file clean of implementation details
+
+    # for now, make these local variables to allow the rest of the function
+    # to be a direct copy-paste
+    weight = module.weight
+    group_size = config.group_size
+    layout = config.layout
+    use_hqq = config.use_hqq
+    zero_point_domain = config.zero_point_domain
+
+    if weight.shape[-1] % group_size != 0:
+        logger.info(
+            f"Skipping quantizing weight with int4 weight only quantization because the shape of weight {weight.shape} is not compatible with group_size {group_size}"
        )
+        return module
+
+    mapping_type = MappingType.ASYMMETRIC
+    block_size = (1, group_size)
+    target_dtype = torch.int32
+    quant_min = 0
+    quant_max = 15
+    eps = 1e-6
+    preserve_zero = LAYOUT_TO_PRESERVE_ZEROS[type(layout)]
+    zero_point_dtype = (
+        weight.dtype if isinstance(layout, Int4CPULayout) else torch.bfloat16
+    )
 
-        nonlocal zero_point_domain
+    # nonlocal zero_point_domain
     assert (
         type(layout) in LAYOUT_TO_ZERO_POINT_DOMAIN.keys()
     ), f"Only support layout: {LAYOUT_TO_ZERO_POINT_DOMAIN.keys()}"
     if zero_point_domain == ZeroPointDomain.NONE:
         # the first value is the default one
         zero_point_domain = LAYOUT_TO_ZERO_POINT_DOMAIN[type(layout)][0]
     else:
         assert (
-            type(layout) in LAYOUT_TO_ZERO_POINT_DOMAIN.keys()
-        ), f"Only support layout: {LAYOUT_TO_ZERO_POINT_DOMAIN.keys()}"
-        if zero_point_domain == ZeroPointDomain.NONE:
-            # the first value is the default one
-            zero_point_domain = LAYOUT_TO_ZERO_POINT_DOMAIN[type(layout)][0]
-        else:
-            assert (
-                zero_point_domain in LAYOUT_TO_ZERO_POINT_DOMAIN[type(layout)]
-            ), f"Layout only support {LAYOUT_TO_ZERO_POINT_DOMAIN[layout]}"
-
-        # Sparse Marlin only supports symmetric quantization.
-        # NOTE: If we start having lots of layouts that require different configurations,
-        # we should consider moving this logic somewhere else.
-        if isinstance(layout, MarlinSparseLayout):
-            mapping_type = MappingType.SYMMETRIC
-            assert (
-                group_size == 128 or group_size == weight.shape[-1]
-            ), f"MarlinSparseLayout only supports 128 group size or per channel quantization, got {group_size}"
+            zero_point_domain in LAYOUT_TO_ZERO_POINT_DOMAIN[type(layout)]
         ), f"Layout only support {LAYOUT_TO_ZERO_POINT_DOMAIN[layout]}"
 
-        return to_affine_quantized_intx(
-            weight,
-            mapping_type,
-            block_size,
-            target_dtype,
-            quant_min,
-            quant_max,
-            eps,
-            zero_point_dtype=zero_point_dtype,
-            preserve_zero=preserve_zero,
-            zero_point_domain=zero_point_domain,
-            _layout=layout,
-            use_hqq=use_hqq,
-        )
+    # Sparse Marlin only supports symmetric quantization.
+    # NOTE: If we start having lots of layouts that require different configurations,
+    # we should consider moving this logic somewhere else.
+    if isinstance(layout, MarlinSparseLayout):
+        mapping_type = MappingType.SYMMETRIC
+        assert (
+            group_size == 128 or group_size == weight.shape[-1]
+        ), f"MarlinSparseLayout only supports 128 group size or per channel quantization, got {group_size}"
 
-    return _get_linear_subclass_inserter(apply_int4_weight_only_quant)
+    new_weight = to_affine_quantized_intx(
+        weight,
+        mapping_type,
+        block_size,
+        target_dtype,
+        quant_min,
+        quant_max,
+        eps,
+        zero_point_dtype=zero_point_dtype,
+        preserve_zero=preserve_zero,
+        zero_point_domain=zero_point_domain,
+        _layout=layout,
+        use_hqq=use_hqq,
+    )
+    module.weight = torch.nn.Parameter(new_weight, requires_grad=False)
+    module.extra_repr = types.MethodType(_linear_extra_repr, module)
+    return module
 
 
 def int8_weight_only(group_size=None):
diff --git a/torchao/quantization/transform_module.py b/torchao/quantization/transform_module.py
new file mode 100644
index 0000000000..1b1f20394b
--- /dev/null
+++ b/torchao/quantization/transform_module.py
@@ -0,0 +1,20 @@
+import functools
+from typing import Callable, Dict
+
+import torch
+
+from torchao.core.config import AOBaseConfig
+
+_QUANTIZE_CONFIG_HANDLER: Dict[
+    AOBaseConfig,
+    Callable[[torch.nn.Module, AOBaseConfig], torch.nn.Module],
+] = {}
+
+
+def register_quantize_module_handler(config_type):
+    @functools.wraps(config_type)
+    def decorator(func):
+        _QUANTIZE_CONFIG_HANDLER[config_type] = func
+        return func  # keep the decorated transform usable by name
+
+    return decorator
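
Taken together, torchao/core/config.py, torchao/quantization/transform_module.py, and the new branch in quantize_ define the extension point this patch introduces: a workflow is described by an AOBaseConfig subclass, and quantize_ looks up the module transform registered for that config type. The sketch below shows how a downstream user might plug into this mechanism, assuming the APIs exactly as added in this diff; WeightClampDemoConfig and _weight_clamp_demo_transform are hypothetical illustration names, not part of torchao, and the int4 call additionally assumes a CUDA GPU.

# minimal sketch of the config-based quantize_ flow introduced above
from dataclasses import dataclass

import torch
import torch.nn as nn

from torchao.core.config import AOBaseConfig
from torchao.quantization import (
    int4_weight_only,  # now an alias for Int4WeightOnlyConfig
    quantize_,
    register_quantize_module_handler,
)

# 1. built-in workflow: int4_weight_only(...) constructs a config object and
#    quantize_ dispatches on type(config) via the handler registry
m = nn.Sequential(nn.Linear(128, 128)).cuda().bfloat16()
quantize_(m, int4_weight_only(group_size=32))


# 2. custom workflow: subclass AOBaseConfig and register a module transform
@dataclass
class WeightClampDemoConfig(AOBaseConfig):  # hypothetical demo config
    limit: float = 3.0


@register_quantize_module_handler(WeightClampDemoConfig)
def _weight_clamp_demo_transform(
    module: torch.nn.Module, config: WeightClampDemoConfig
) -> torch.nn.Module:
    # toy transform that clamps weights in place; a real handler would swap in
    # a quantized weight tensor subclass, as _int4_weight_only_transform does
    module.weight = torch.nn.Parameter(
        module.weight.clamp(-config.limit, config.limit), requires_grad=False
    )
    return module


m2 = nn.Sequential(nn.Linear(16, 16))
quantize_(m2, WeightClampDemoConfig(limit=1.0))  # applied to every nn.Linear

The same registration path is what the diff uses internally for Int4WeightOnlyConfig and the two QAT configs, while callers that still pass the old callables keep working through the deprecation branch in quantize_.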