From 24114cebb3fd77737185b1e30bef050283c51478 Mon Sep 17 00:00:00 2001 From: vasiliy Date: Wed, 22 Jan 2025 08:49:11 -0800 Subject: [PATCH 01/23] Update [ghstack-poisoned] --- test/quantization/test_qat.py | 2 +- test/quantization/test_quant_api.py | 26 +++ torchao/core/__init__.py | 0 torchao/core/config.py | 13 ++ torchao/quantization/_transform_module.py | 17 ++ torchao/quantization/qat/api.py | 114 +++++++------ torchao/quantization/quant_api.py | 191 ++++++++++++++-------- 7 files changed, 249 insertions(+), 114 deletions(-) create mode 100644 torchao/core/__init__.py create mode 100644 torchao/core/config.py create mode 100644 torchao/quantization/_transform_module.py diff --git a/test/quantization/test_qat.py b/test/quantization/test_qat.py index 8a78b8b387..82324394a8 100644 --- a/test/quantization/test_qat.py +++ b/test/quantization/test_qat.py @@ -1185,7 +1185,7 @@ def test_qat_prototype_bc(self): @unittest.skipIf( not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" ) - def test_quantize_api(self): + def test_quantize_api_standalone(self): """ Test that the following: diff --git a/test/quantization/test_quant_api.py b/test/quantization/test_quant_api.py index 177c357047..ca2cbf08ec 100644 --- a/test/quantization/test_quant_api.py +++ b/test/quantization/test_quant_api.py @@ -40,6 +40,7 @@ Int4WeightOnlyQuantizedLinearWeight, Int8WeightOnlyQuantizedLinearWeight, ) +from torchao.quantization.utils import compute_error from torchao.utils import ( TORCH_VERSION_AT_LEAST_2_3, TORCH_VERSION_AT_LEAST_2_4, @@ -761,6 +762,31 @@ def reset_memory(): assert param.is_cuda self.assertLess(memory_streaming, memory_baseline) + def test_int4_weight_only_numerics(self): + """ + Simple test of e2e int4_weight_only workflow, comparing numerics + to a bfloat16 baseline. + """ + # TODO(before land) skip on cpu-only + # TODO(before land) support other inference techniques? + + # set up inputs + x = torch.randn(128, 128, device="cuda", dtype=torch.bfloat16) + # TODO: model in float32 leads to error: https://gist.github.com/vkuzo/63b3bcd7818393021a6e3fb4ccf3c469 + # is that expected? + m_ref = torch.nn.Sequential(torch.nn.Linear(128, 128)).cuda().bfloat16() + m_int4_wo = copy.deepcopy(m_ref) + + # quantize + quantize_(m_int4_wo, int4_weight_only()) + + with torch.no_grad(): + y_ref = m_ref(x) + y_int4_wo = m_int4_wo(x) + + sqnr = compute_error(y_ref, y_int4_wo) + assert sqnr >= 20, f"SQNR {sqnr} is too low" + class TestMultiTensorFlow(TestCase): @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+") diff --git a/torchao/core/__init__.py b/torchao/core/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/torchao/core/config.py b/torchao/core/config.py new file mode 100644 index 0000000000..fbc1216212 --- /dev/null +++ b/torchao/core/config.py @@ -0,0 +1,13 @@ +import abc + + +# directory location for this might need more polish +class AOBaseWorkflowConfig(abc.ABC): + """ + If a workflow config inherits from this then `quantize_` knows + what to do with it. + + TODO write a better docblock. 
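# Illustrative sketch, not part of this patch: how a workflow config and its
# module handler are intended to fit together under this design. The names
# `MyWorkflowConfig` and `_my_workflow_transform` are hypothetical; the imports
# come from the files added in this stack.
from dataclasses import dataclass

import torch

from torchao.core.config import AOBaseWorkflowConfig
from torchao.quantization._transform_module import register_quantize_module_handler


@dataclass
class MyWorkflowConfig(AOBaseWorkflowConfig):
    # example knob; a real config carries whatever its workflow needs
    group_size: int = 128


@register_quantize_module_handler(MyWorkflowConfig)
def _my_workflow_transform(
    module: torch.nn.Module, config: MyWorkflowConfig
) -> torch.nn.Module:
    # a real handler would swap module.weight for a quantized tensor here
    return module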
+ """ + + pass diff --git a/torchao/quantization/_transform_module.py b/torchao/quantization/_transform_module.py new file mode 100644 index 0000000000..f14e79b5a9 --- /dev/null +++ b/torchao/quantization/_transform_module.py @@ -0,0 +1,17 @@ +from typing import Callable, Dict + +import torch + +from torchao.core.config import AOBaseWorkflowConfig + +_QUANTIZE_CONFIG_HANDLER: Dict[ + AOBaseWorkflowConfig, + Callable[[torch.nn.Module, AOBaseWorkflowConfig], torch.nn.Module], +] = {} + + +def register_quantize_module_handler(config_type): + def decorator(func): + _QUANTIZE_CONFIG_HANDLER[config_type] = func + + return decorator diff --git a/torchao/quantization/qat/api.py b/torchao/quantization/qat/api.py index cd3813291f..6356ee1600 100644 --- a/torchao/quantization/qat/api.py +++ b/torchao/quantization/qat/api.py @@ -5,10 +5,14 @@ # LICENSE file in the root directory of this source tree. from dataclasses import dataclass -from typing import Any, Callable, List, Optional, Union +from typing import Any, List, Optional, Union import torch +from torchao.core.config import AOBaseWorkflowConfig +from torchao.quantization._transform_module import ( + register_quantize_module_handler, +) from torchao.quantization.granularity import ( Granularity, PerAxis, @@ -239,12 +243,26 @@ def __setattr__(self, name: str, value: Any): super().__setattr__(name, value) -def intx_quantization_aware_training( - activation_config: Optional[FakeQuantizeConfig] = None, - weight_config: Optional[FakeQuantizeConfig] = None, -) -> Callable: +@dataclass +class IntXQuantizationAwareTrainingConfig(AOBaseWorkflowConfig): + activation_config: Optional[FakeQuantizeConfig] = None + weight_config: Optional[FakeQuantizeConfig] = None + + +# for BC +intx_quantization_aware_training = IntXQuantizationAwareTrainingConfig + + +@register_quantize_module_handler(IntXQuantizationAwareTrainingConfig) +def _intx_quantization_aware_training_transform( + module: torch.nn.Module, + config: IntXQuantizationAwareTrainingConfig, +) -> torch.nn.Module: """ - Return a function that applies fake quantization to a `torch.nn.Module`. + THIS IS NOT A PUBLIC API - any usage of this outside of torchao + can break at any time. + + Apply fake quantization to a `torch.nn.Module`. to be used with :func:`~torchao.quantization.quant_api.quantize_`. Example usage:: @@ -267,37 +285,32 @@ def intx_quantization_aware_training( `torch.nn.Embedding` with an activation config, then we will raise ValueError as these are not supported. """ - - def _insert_fake_quantize(mod: torch.nn.Module): - """ - Swap the given module with its corresponding fake quantized version. 
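# Illustrative sketch, not part of this patch: preparing a model for QAT with
# the new config object, mirroring the docstring example above. The
# `intx_quantization_aware_training` alias keeps the old callable-style name
# working.
import torch

from torchao.quantization import quantize_
from torchao.quantization.qat import (
    FakeQuantizeConfig,
    IntXQuantizationAwareTrainingConfig,
)

model = torch.nn.Sequential(torch.nn.Linear(32, 32))
activation_config = FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
weight_config = FakeQuantizeConfig(torch.int4, group_size=32)
quantize_(
    model,
    IntXQuantizationAwareTrainingConfig(activation_config, weight_config),
)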
- """ - from .embedding import FakeQuantizedEmbedding - from .linear import FakeQuantizedLinear - - if isinstance(mod, torch.nn.Linear): - return FakeQuantizedLinear.from_linear( - mod, - activation_config, - weight_config, - ) - elif isinstance(mod, torch.nn.Embedding): - if activation_config is not None: - raise ValueError( - "Activation fake quantization is not supported for embedding" - ) - return FakeQuantizedEmbedding.from_embedding(mod, weight_config) - else: + from .embedding import FakeQuantizedEmbedding + from .linear import FakeQuantizedLinear + + mod = module + activation_config = config.activation_config + weight_config = config.weight_config + + if isinstance(mod, torch.nn.Linear): + return FakeQuantizedLinear.from_linear( + mod, + activation_config, + weight_config, + ) + elif isinstance(mod, torch.nn.Embedding): + if activation_config is not None: raise ValueError( - "Module of type '%s' does not have QAT support" % type(mod) + "Activation fake quantization is not supported for embedding" ) + return FakeQuantizedEmbedding.from_embedding(mod, weight_config) + else: + raise ValueError("Module of type '%s' does not have QAT support" % type(mod)) - return _insert_fake_quantize - -def from_intx_quantization_aware_training() -> Callable: +class FromIntXQuantizationAwareTrainingConfig(AOBaseWorkflowConfig): """ - Return a function that converts a model with fake quantized modules, + Object that knows how to convert a model with fake quantized modules, such as :func:`~torchao.quantization.qat.linear.FakeQuantizedLinear` and :func:`~torchao.quantization.qat.linear.FakeQuantizedEmbedding`, back to model with the original, corresponding modules without @@ -313,22 +326,31 @@ def from_intx_quantization_aware_training() -> Callable: ) """ - def _remove_fake_quantize(mod: torch.nn.Module): - """ - If the given module is a fake quantized module, return the original - corresponding version of the module without fake quantization. - """ - from .embedding import FakeQuantizedEmbedding - from .linear import FakeQuantizedLinear + pass + + +# for BC +from_intx_quantization_aware_training = FromIntXQuantizationAwareTrainingConfig - if isinstance(mod, FakeQuantizedLinear): - return mod.to_linear() - elif isinstance(mod, FakeQuantizedEmbedding): - return mod.to_embedding() - else: - return mod - return _remove_fake_quantize +@register_quantize_module_handler(FromIntXQuantizationAwareTrainingConfig) +def _from_intx_quantization_aware_training_transform( + mod: torch.nn.Module, + config: FromIntXQuantizationAwareTrainingConfig, +) -> torch.nn.Module: + """ + If the given module is a fake quantized module, return the original + corresponding version of the module without fake quantization. 
+ """ + from .embedding import FakeQuantizedEmbedding + from .linear import FakeQuantizedLinear + + if isinstance(mod, FakeQuantizedLinear): + return mod.to_linear() + elif isinstance(mod, FakeQuantizedEmbedding): + return mod.to_embedding() + else: + return mod class ComposableQATQuantizer(TwoStepQuantizer): diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index b2eff196fd..450563be36 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -18,13 +18,15 @@ import logging import types import warnings -from typing import Callable, Optional, Tuple, Union +from dataclasses import dataclass +from typing import Any, Callable, Optional, Tuple, Union import torch import torch.nn as nn import torch.nn.utils.parametrize as parametrize import torchao +from torchao.core.config import AOBaseWorkflowConfig from torchao.dtypes import ( AffineQuantizedTensor, CutlassInt4PackedLayout, @@ -43,6 +45,10 @@ ) from torchao.float8.float8_linear import Float8Linear from torchao.float8.inference import Float8MMConfig +from torchao.quantization._transform_module import ( + _QUANTIZE_CONFIG_HANDLER, + register_quantize_module_handler, +) from torchao.quantization.linear_activation_weight_observed_tensor import ( LinearActivationWeightObservedTensor, ) @@ -117,7 +123,6 @@ "Int8DynActInt4WeightGPTQQuantizer", ] -# update according to the support matrix LAYOUT_TO_ZERO_POINT_DOMAIN = { TensorCoreTiledLayout: [ZeroPointDomain.FLOAT], MarlinSparseLayout: [ZeroPointDomain.INT], @@ -228,6 +233,7 @@ def _replace_with_custom_fn_if_matches_filter( filter_fn, cur_fqn="", device=None, + extra_args: Optional[Tuple[Any, ...]] = (), ) -> None: """ Recursively replaces each child module in `model` with the result of `replacement_fn(child)` @@ -239,6 +245,7 @@ def _replace_with_custom_fn_if_matches_filter( filter_fn (Callable[[torch.nn.Module], bool]): The filter function to determine which modules to replace. cur_fqn (str, optional): The current fully qualified name of the module being processed. Defaults to "". device (device, optional): Device to move the model to before applying `filter_fn`. Defaults to None. + extra_args (Tuple[Any, ...], optional): optional extra args to pass to `replacement_fn`. 
Returns: None @@ -252,12 +259,17 @@ def _replace_with_custom_fn_if_matches_filter( if filter_fn(model, cur_fqn[:-1]): if device is not None: model.to(device=device) # move to device before quantization - model = replacement_fn(model) + model = replacement_fn(model, *extra_args) return model else: for name, child in model.named_children(): new_child = _replace_with_custom_fn_if_matches_filter( - child, replacement_fn, filter_fn, f"{cur_fqn}{name}.", device + child, + replacement_fn, + filter_fn, + f"{cur_fqn}{name}.", + device, + extra_args, ) if new_child is not child: setattr(model, name, new_child) @@ -468,7 +480,10 @@ def insert_subclass(lin): def quantize_( model: torch.nn.Module, - apply_tensor_subclass: Callable[[torch.nn.Module], torch.nn.Module], + # apply_tensor_subclass: Callable[[torch.nn.Module], torch.nn.Module], + apply_tensor_subclass: Union[ + Callable[[torch.nn.Module], torch.nn.Module], AOBaseWorkflowConfig + ], filter_fn: Optional[Callable[[torch.nn.Module, str], bool]] = None, set_inductor_config: bool = True, device: Optional[torch.types.Device] = None, @@ -530,12 +545,33 @@ def filter_fn(module: nn.Module, fqn: str) -> bool: if set_inductor_config: torchao.quantization.utils.recommended_inductor_config_setter() - _replace_with_custom_fn_if_matches_filter( - model, - apply_tensor_subclass, - _is_linear if filter_fn is None else filter_fn, - device=device, - ) + if isinstance(apply_tensor_subclass, AOBaseWorkflowConfig): + # new behavior + + # make the variable name make sense + config = apply_tensor_subclass + handler = _QUANTIZE_CONFIG_HANDLER[type(config)] + + # for each linear in the model, apply the transform if filtering passes + # key difference from old is that `config_with_transform` is easily + # inspectable + _replace_with_custom_fn_if_matches_filter( + model, + handler, + _is_linear if filter_fn is None else filter_fn, + device=device, + extra_args=(config,), + ) + + else: + # old behavior, for now keep for BC purposes + # TODO(after discussion): flesh the BC story out more + _replace_with_custom_fn_if_matches_filter( + model, + apply_tensor_subclass, + _is_linear if filter_fn is None else filter_fn, + device=device, + ) def _int8_asymm_per_token_quant(x: torch.Tensor) -> torch.Tensor: @@ -684,14 +720,10 @@ def gemlite_uintx_weight_only( return _get_linear_subclass_inserter(apply_fn) -def int4_weight_only( - group_size=128, - layout=TensorCoreTiledLayout(inner_k_tiles=8), - use_hqq=False, - zero_point_domain=None, -): +@dataclass +class Int4WeightOnlyConfig(AOBaseWorkflowConfig): """ - Applies uint4 weight-only asymmetric per-group quantization to linear layers, using + Configuration for applying uint4 weight-only asymmetric per-group quantization to linear layers, using "tensor_core_tiled" layout for speedup with tinygemm kernel Note: @@ -711,59 +743,84 @@ def int4_weight_only( `zero_point_domain`: data type of zeros points, choices are [None(then the value is determined by the layout), ZeroPointDomain.FLOAT, ZeroPointDomain.INT, ZeroPointDomain.NONE] """ - def apply_int4_weight_only_quant(weight): - if weight.shape[-1] % group_size != 0: - logger.info( - f"Skipping quantizing weight with int4 weight only quantization because the shape of weight {weight.shape} is not compatible with group_size {group_size}" - ) - return weight + group_size: int = 128 + layout: Optional[TensorCoreTiledLayout] = TensorCoreTiledLayout(inner_k_tiles=8) + use_hqq: bool = False + zero_point_domain: Optional[ZeroPointDomain] = None - mapping_type = MappingType.ASYMMETRIC - 
block_size = (1, group_size) - target_dtype = torch.int32 - quant_min = 0 - quant_max = 15 - eps = 1e-6 - preserve_zero = LAYOUT_TO_PRESERVE_ZEROS[type(layout)] - zero_point_dtype = torch.bfloat16 - - nonlocal zero_point_domain - assert ( - type(layout) in LAYOUT_TO_ZERO_POINT_DOMAIN.keys() - ), f"Only support layout: {LAYOUT_TO_ZERO_POINT_DOMAIN.keys()}" - if zero_point_domain is None: - # the first value is the default one - zero_point_domain = LAYOUT_TO_ZERO_POINT_DOMAIN[type(layout)][0] - else: - assert ( - zero_point_domain in LAYOUT_TO_ZERO_POINT_DOMAIN[type(layout)] - ), f"Layout only support {LAYOUT_TO_ZERO_POINT_DOMAIN[layout]}" - - # Sparse Marlin only supports symmetric quantization. - # NOTE: If we start having lots of layouts that require different configurations, - # we should consider moving this logic somewhere else. - if isinstance(layout, MarlinSparseLayout): - mapping_type = MappingType.SYMMETRIC - assert ( - group_size == 128 or group_size == weight.shape[-1] - ), f"MarlinSparseLayout only supports 128 group size or per channel quantization, got {group_size}" - return to_affine_quantized_intx( - weight, - mapping_type, - block_size, - target_dtype, - quant_min, - quant_max, - eps, - zero_point_dtype=zero_point_dtype, - preserve_zero=preserve_zero, - zero_point_domain=zero_point_domain, - _layout=layout, - use_hqq=use_hqq, +# for BC +# TODO maybe change other callsites +int4_weight_only = Int4WeightOnlyConfig + + +@register_quantize_module_handler(Int4WeightOnlyConfig) +def _int4_weight_only_transform( + module: torch.nn.Module, config: Int4WeightOnlyConfig +) -> torch.nn.Module: + # TODO(future PR): perhaps move this logic to a different file, to keep the API + # file clean of implementation details + + # for now, make these local variables to allow the rest of the function + # to be a direct copy-paste + weight = module.weight + group_size = config.group_size + layout = config.layout + use_hqq = config.use_hqq + zero_point_domain = config.zero_point_domain + + if weight.shape[-1] % group_size != 0: + logger.info( + f"Skipping quantizing weight with int4 weight only quantization because the shape of weight {weight.shape} is not compatible with group_size {group_size}" ) + return weight - return _get_linear_subclass_inserter(apply_int4_weight_only_quant) + mapping_type = MappingType.ASYMMETRIC + block_size = (1, group_size) + target_dtype = torch.int32 + quant_min = 0 + quant_max = 15 + eps = 1e-6 + preserve_zero = LAYOUT_TO_PRESERVE_ZEROS[type(layout)] + zero_point_dtype = torch.bfloat16 + + # nonlocal zero_point_domain + assert ( + type(layout) in LAYOUT_TO_ZERO_POINT_DOMAIN.keys() + ), f"Only support layout: {LAYOUT_TO_ZERO_POINT_DOMAIN.keys()}" + if zero_point_domain is None: + # the first value is the default one + zero_point_domain = LAYOUT_TO_ZERO_POINT_DOMAIN[type(layout)][0] + else: + assert ( + zero_point_domain in LAYOUT_TO_ZERO_POINT_DOMAIN[type(layout)] + ), f"Layout only support {LAYOUT_TO_ZERO_POINT_DOMAIN[layout]}" + + # Sparse Marlin only supports symmetric quantization. + # NOTE: If we start having lots of layouts that require different configurations, + # we should consider moving this logic somewhere else. 
+ if isinstance(layout, MarlinSparseLayout): + mapping_type = MappingType.SYMMETRIC + assert ( + group_size == 128 or group_size == weight.shape[-1] + ), f"MarlinSparseLayout only supports 128 group size or per channel quantization, got {group_size}" + + new_weight = to_affine_quantized_intx( + weight, + mapping_type, + block_size, + target_dtype, + quant_min, + quant_max, + eps, + zero_point_dtype=zero_point_dtype, + preserve_zero=preserve_zero, + zero_point_domain=zero_point_domain, + _layout=layout, + use_hqq=use_hqq, + ) + module.weight = torch.nn.Parameter(new_weight) + return module def int8_weight_only(group_size=None): From 5b9d876d7ea41db7964278c6b59b27e6b79645fb Mon Sep 17 00:00:00 2001 From: vasiliy Date: Wed, 22 Jan 2025 10:08:28 -0800 Subject: [PATCH 02/23] Update [ghstack-poisoned] --- test/dtypes/test_affine_quantized.py | 7 +++- test/quantization/test_quant_api.py | 7 ++-- torchao/quantization/quant_api.py | 58 ++++++++-------------------- 3 files changed, 25 insertions(+), 47 deletions(-) diff --git a/test/dtypes/test_affine_quantized.py b/test/dtypes/test_affine_quantized.py index f08ba7aa72..1b4bf58cf9 100644 --- a/test/dtypes/test_affine_quantized.py +++ b/test/dtypes/test_affine_quantized.py @@ -8,6 +8,7 @@ run_tests, ) +from torchao.core.config import AOBaseWorkflowConfig from torchao.dtypes import CutlassInt4PackedLayout, Int4CPULayout, SemiSparseLayout from torchao.quantization import ( float8_weight_only, @@ -15,6 +16,7 @@ int8_dynamic_activation_int4_weight, int8_dynamic_activation_int8_weight, int8_weight_only, + quantize_, ) from torchao.quantization.quant_primitives import MappingType, ZeroPointDomain from torchao.utils import ( @@ -186,7 +188,10 @@ def test_flatten_unflatten(self, device, dtype): apply_quant_list = get_quantization_functions(False, True, device) for apply_quant in apply_quant_list: linear = torch.nn.Linear(128, 256, dtype=dtype, device=device) - ql = apply_quant(linear) + if isinstance(apply_quant, AOBaseWorkflowConfig): + quantize_(linear, apply_quant) + else: + ql = apply_quant(linear) lp_tensor = ql.weight tensor_data_name_dict, tensor_attributes = lp_tensor.__tensor_flatten__() tensor_data_dict = { diff --git a/test/quantization/test_quant_api.py b/test/quantization/test_quant_api.py index ca2cbf08ec..80536bfac9 100644 --- a/test/quantization/test_quant_api.py +++ b/test/quantization/test_quant_api.py @@ -762,17 +762,16 @@ def reset_memory(): assert param.is_cuda self.assertLess(memory_streaming, memory_baseline) + @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+") + @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") def test_int4_weight_only_numerics(self): """ Simple test of e2e int4_weight_only workflow, comparing numerics to a bfloat16 baseline. """ - # TODO(before land) skip on cpu-only - # TODO(before land) support other inference techniques? - # set up inputs x = torch.randn(128, 128, device="cuda", dtype=torch.bfloat16) - # TODO: model in float32 leads to error: https://gist.github.com/vkuzo/63b3bcd7818393021a6e3fb4ccf3c469 + # TODO(future): model in float32 leads to error: https://gist.github.com/vkuzo/63b3bcd7818393021a6e3fb4ccf3c469 # is that expected? 
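        # baseline model in bfloat16, and a deep copy of it that will be quantized with int4 weight-only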
m_ref = torch.nn.Sequential(torch.nn.Linear(128, 128)).cuda().bfloat16() m_int4_wo = copy.deepcopy(m_ref) diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index 450563be36..e36bc7d8e3 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -262,7 +262,8 @@ def _replace_with_custom_fn_if_matches_filter( model = replacement_fn(model, *extra_args) return model else: - for name, child in model.named_children(): + named_children_list = list(model.named_children()) + for name, child in named_children_list: new_child = _replace_with_custom_fn_if_matches_filter( child, replacement_fn, @@ -480,20 +481,19 @@ def insert_subclass(lin): def quantize_( model: torch.nn.Module, - # apply_tensor_subclass: Callable[[torch.nn.Module], torch.nn.Module], - apply_tensor_subclass: Union[ - Callable[[torch.nn.Module], torch.nn.Module], AOBaseWorkflowConfig + config: Union[ + AOBaseWorkflowConfig, Callable[[torch.nn.Module], torch.nn.Module] ], filter_fn: Optional[Callable[[torch.nn.Module, str], bool]] = None, set_inductor_config: bool = True, device: Optional[torch.types.Device] = None, ): - """Convert the weight of linear modules in the model with `apply_tensor_subclass`, model is modified inplace + """Convert the weight of linear modules in the model with `config`, model is modified inplace Args: model (torch.nn.Module): input model - apply_tensor_subclass (Callable[[torch.nn.Module], torch.nn.Module]): function that applies tensor subclass conversion to the weight of a module and return the module (e.g. convert the weight tensor of linear to affine quantized tensor) - filter_fn (Optional[Callable[[torch.nn.Module, str], bool]]): function that takes a nn.Module instance and fully qualified name of the module, returns True if we want to run `apply_tensor_subclass` on + config (Union[AOBaseWorkflowConfig, Callable[[torch.nn.Module], torch.nn.Module]]): either (1) a workflow configuration object or (2) a function that applies tensor subclass conversion to the weight of a module and return the module (e.g. convert the weight tensor of linear to affine quantized tensor). Note: (2) will be deleted in a future release. + filter_fn (Optional[Callable[[torch.nn.Module, str], bool]]): function that takes a nn.Module instance and fully qualified name of the module, returns True if we want to run `config` on the weight of the module set_inductor_config (bool, optional): Whether to automatically use recommended inductor config settings (defaults to True) device (device, optional): Device to move module to before applying `filter_fn`. This can be set to `"cuda"` to speed up quantization. The final model will be on the specified `device`. @@ -505,7 +505,7 @@ def quantize_( import torch.nn as nn from torchao import quantize_ - # 1. quantize with some predefined `apply_tensor_subclass` method that corresponds to + # quantize with some predefined `config` method that corresponds to # optimized execution paths or kernels (e.g. int4 tinygemm kernel) # also customizable with arguments # currently options are @@ -518,43 +518,13 @@ def quantize_( m = nn.Sequential(nn.Linear(32, 1024), nn.Linear(1024, 32)) quantize_(m, int4_weight_only(group_size=32)) - # 2. 
write your own new apply_tensor_subclass - # You can also add your own apply_tensor_subclass by manually calling tensor subclass constructor - # on weight - - from torchao.dtypes import to_affine_quantized_intx - - # weight only uint4 asymmetric groupwise quantization - groupsize = 32 - apply_weight_quant = lambda x: to_affine_quantized_intx( - x, "asymmetric", (1, groupsize), torch.int32, 0, 15, 1e-6, - zero_point_dtype=torch.bfloat16, preserve_zero=False, zero_point_domain="float") - - def apply_weight_quant_to_linear(linear): - linear.weight = torch.nn.Parameter(apply_weight_quant(linear.weight), requires_grad=False) - return linear - - # apply to modules under block0 submodule - def filter_fn(module: nn.Module, fqn: str) -> bool: - return isinstance(module, nn.Linear) - - m = nn.Sequential(nn.Linear(32, 1024), nn.Linear(1024, 32)) - quantize_(m, apply_weight_quant_to_linear, filter_fn) - """ if set_inductor_config: torchao.quantization.utils.recommended_inductor_config_setter() - if isinstance(apply_tensor_subclass, AOBaseWorkflowConfig): - # new behavior - - # make the variable name make sense - config = apply_tensor_subclass + if isinstance(config, AOBaseWorkflowConfig): handler = _QUANTIZE_CONFIG_HANDLER[type(config)] - # for each linear in the model, apply the transform if filtering passes - # key difference from old is that `config_with_transform` is easily - # inspectable _replace_with_custom_fn_if_matches_filter( model, handler, @@ -564,8 +534,12 @@ def filter_fn(module: nn.Module, fqn: str) -> bool: ) else: - # old behavior, for now keep for BC purposes - # TODO(after discussion): flesh the BC story out more + # old behavior, keep to avoid breaking BC + warnings.warn("""Passing a generic Callable to `quantize_` is no longer recommended and will be deprecated at a later release. 
Please see https://github.com/pytorch/ao/pull/1595 for instructions on how to pass in workflow configuration instead.""") + + # make the variable name make sense + apply_tensor_subclass = config + _replace_with_custom_fn_if_matches_filter( model, apply_tensor_subclass, @@ -773,7 +747,7 @@ def _int4_weight_only_transform( logger.info( f"Skipping quantizing weight with int4 weight only quantization because the shape of weight {weight.shape} is not compatible with group_size {group_size}" ) - return weight + return module mapping_type = MappingType.ASYMMETRIC block_size = (1, group_size) From 1cea42fbd49f534c697471f9c35c424768607985 Mon Sep 17 00:00:00 2001 From: vasiliy Date: Wed, 22 Jan 2025 10:39:15 -0800 Subject: [PATCH 03/23] Update [ghstack-poisoned] --- test/dtypes/test_affine_quantized.py | 2 +- torchao/quantization/quant_api.py | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/test/dtypes/test_affine_quantized.py b/test/dtypes/test_affine_quantized.py index 1b4bf58cf9..9ef26026e2 100644 --- a/test/dtypes/test_affine_quantized.py +++ b/test/dtypes/test_affine_quantized.py @@ -189,7 +189,7 @@ def test_flatten_unflatten(self, device, dtype): for apply_quant in apply_quant_list: linear = torch.nn.Linear(128, 256, dtype=dtype, device=device) if isinstance(apply_quant, AOBaseWorkflowConfig): - quantize_(linear, apply_quant) + quantize_(linear, apply_quant) else: ql = apply_quant(linear) lp_tensor = ql.weight diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index e36bc7d8e3..efda1dbb23 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -481,9 +481,7 @@ def insert_subclass(lin): def quantize_( model: torch.nn.Module, - config: Union[ - AOBaseWorkflowConfig, Callable[[torch.nn.Module], torch.nn.Module] - ], + config: Union[AOBaseWorkflowConfig, Callable[[torch.nn.Module], torch.nn.Module]], filter_fn: Optional[Callable[[torch.nn.Module, str], bool]] = None, set_inductor_config: bool = True, device: Optional[torch.types.Device] = None, @@ -535,7 +533,9 @@ def quantize_( else: # old behavior, keep to avoid breaking BC - warnings.warn("""Passing a generic Callable to `quantize_` is no longer recommended and will be deprecated at a later release. Please see https://github.com/pytorch/ao/pull/1595 for instructions on how to pass in workflow configuration instead.""") + warnings.warn( + """Passing a generic Callable to `quantize_` is no longer recommended and will be deprecated at a later release. 
Please see https://github.com/pytorch/ao/pull/1595 for instructions on how to pass in workflow configuration instead.""" + ) # make the variable name make sense apply_tensor_subclass = config From 138883b4f40073517c1a5a71dd87c00d33c87d43 Mon Sep 17 00:00:00 2001 From: vasiliy Date: Wed, 22 Jan 2025 12:44:06 -0800 Subject: [PATCH 04/23] Update [ghstack-poisoned] --- test/dtypes/test_affine_quantized.py | 19 +++++++++++++++---- test/hqq/test_hqq_affine.py | 7 ++++--- torchao/quantization/quant_api.py | 1 + 3 files changed, 20 insertions(+), 7 deletions(-) diff --git a/test/dtypes/test_affine_quantized.py b/test/dtypes/test_affine_quantized.py index 9ef26026e2..671c676e76 100644 --- a/test/dtypes/test_affine_quantized.py +++ b/test/dtypes/test_affine_quantized.py @@ -60,7 +60,8 @@ def get_quantization_functions( ) ) - if do_sparse: + # TODO(before land): revert this back, added due to lack of cuSparseLt in my env + if do_sparse and False: base_functions.append( int8_dynamic_activation_int8_weight(layout=SemiSparseLayout()) ) @@ -78,7 +79,8 @@ def test_tensor_core_layout_transpose(self): t = linear.weight shape = t.shape apply_int4_weight_only_quant = int4_weight_only(group_size=32) - ql = apply_int4_weight_only_quant(linear) + quantize_(linear, apply_int4_weight_only_quant) + ql = linear aqt = ql.weight aqt_shape = aqt.shape self.assertEqual(aqt_shape, shape) @@ -97,7 +99,11 @@ def test_tensor_core_layout_transpose(self): ) def test_weights_only(self, apply_quant): linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda") - ql = apply_quant(linear) + if isinstance(apply_quant, AOBaseWorkflowConfig): + quantize_(linear, apply_quant) + ql = linear + else: + ql = apply_quant(linear) with tempfile.NamedTemporaryFile() as f: torch.save(ql.state_dict(), f) f.seek(0) @@ -173,8 +179,13 @@ def apply_uint6_weight_only_quant(linear): @common_utils.parametrize("apply_quant", get_quantization_functions(True, True)) @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") def test_print_quantized_module(self, apply_quant): + print(apply_quant) linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda") - ql = apply_quant(linear) + if isinstance(apply_quant, AOBaseWorkflowConfig): + quantize_(linear, apply_quant) + ql = linear + else: + ql = apply_quant(linear) assert "AffineQuantizedTensor" in str(ql) diff --git a/test/hqq/test_hqq_affine.py b/test/hqq/test_hqq_affine.py index 381886d594..096c9d26ba 100644 --- a/test/hqq/test_hqq_affine.py +++ b/test/hqq/test_hqq_affine.py @@ -6,6 +6,7 @@ MappingType, ZeroPointDomain, int4_weight_only, + quantize_, uintx_weight_only, ) from torchao.utils import ( @@ -51,9 +52,9 @@ def _eval_hqq(dtype): ) dummy_linear.weight.data = W if dtype == torch.uint4: - q_tensor_hqq = int4_weight_only(group_size=max(block_size), use_hqq=True)( - dummy_linear - ).weight + config = int4_weight_only(group_size=max(block_size), use_hqq=True) + quantize_(dummy_linear, config) + q_tensor_hqq = dummy_linear.weight else: q_tensor_hqq = uintx_weight_only( dtype, group_size=max(block_size), use_hqq=True diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index efda1dbb23..1c7284a01d 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -794,6 +794,7 @@ def _int4_weight_only_transform( use_hqq=use_hqq, ) module.weight = torch.nn.Parameter(new_weight) + module.extra_repr = types.MethodType(_linear_extra_repr, module) return module From ba045ea89316a7a14b92d4849f44e9ff1ad276f5 Mon Sep 17 
00:00:00 2001 From: vasiliy Date: Wed, 22 Jan 2025 12:56:28 -0800 Subject: [PATCH 05/23] Update [ghstack-poisoned] --- test/dtypes/test_affine_quantized.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/test/dtypes/test_affine_quantized.py b/test/dtypes/test_affine_quantized.py index 671c676e76..2cb87ab133 100644 --- a/test/dtypes/test_affine_quantized.py +++ b/test/dtypes/test_affine_quantized.py @@ -60,8 +60,7 @@ def get_quantization_functions( ) ) - # TODO(before land): revert this back, added due to lack of cuSparseLt in my env - if do_sparse and False: + if do_sparse: base_functions.append( int8_dynamic_activation_int8_weight(layout=SemiSparseLayout()) ) From 94d942606bcea5bad5c36b819d779deaa7c1572b Mon Sep 17 00:00:00 2001 From: vasiliy Date: Wed, 22 Jan 2025 15:08:47 -0800 Subject: [PATCH 06/23] Update [ghstack-poisoned] --- torchao/quantization/quant_api.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index 1c7284a01d..3401a42ab7 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -793,7 +793,7 @@ def _int4_weight_only_transform( _layout=layout, use_hqq=use_hqq, ) - module.weight = torch.nn.Parameter(new_weight) + module.weight = torch.nn.Parameter(new_weight, requires_grad=False) module.extra_repr = types.MethodType(_linear_extra_repr, module) return module From 26850dae92bdcf6535fcf30ca4fc21f4074bde44 Mon Sep 17 00:00:00 2001 From: vasiliy Date: Wed, 5 Feb 2025 13:34:44 -0800 Subject: [PATCH 07/23] Update [ghstack-poisoned] --- torchao/quantization/quant_api.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index c598393b50..a6e8ee8e0b 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -779,6 +779,7 @@ class Int4WeightOnlyConfig(AOBaseConfig): use_hqq: bool = False zero_point_domain: Optional[ZeroPointDomain] = ZeroPointDomain.NONE + # for BC # TODO maybe change other callsites int4_weight_only = Int4WeightOnlyConfig @@ -812,7 +813,9 @@ def _int4_weight_only_transform( quant_max = 15 eps = 1e-6 preserve_zero = LAYOUT_TO_PRESERVE_ZEROS[type(layout)] - zero_point_dtype = torch.bfloat16 + zero_point_dtype = ( + weight.dtype if isinstance(layout, Int4CPULayout) else torch.bfloat16 + ) # nonlocal zero_point_domain assert ( From d42a59070b95f85983226d280f2de60a7dbf8735 Mon Sep 17 00:00:00 2001 From: vasiliy Date: Mon, 10 Feb 2025 15:33:48 -0800 Subject: [PATCH 08/23] Update [ghstack-poisoned] --- test/dtypes/test_affine_quantized.py | 14 +- test/quantization/test_quant_api.py | 25 ++- torchao/quantization/quant_api.py | 246 ++++++++++++++++----------- 3 files changed, 180 insertions(+), 105 deletions(-) diff --git a/test/dtypes/test_affine_quantized.py b/test/dtypes/test_affine_quantized.py index 53ca470b04..d26f1d8e04 100644 --- a/test/dtypes/test_affine_quantized.py +++ b/test/dtypes/test_affine_quantized.py @@ -123,16 +123,24 @@ def test_weights_only(self, apply_quant): @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") @common_utils.parametrize("apply_quant", get_quantization_functions(False, False)) def test_to_device(self, apply_quant): + def _apply(module, config_or_subclass_inserter): + if isinstance(config_or_subclass_inserter, AOBaseConfig): + quantize_(module, config_or_subclass_inserter) + else: + # TODO(#1690): delete this once config migration is done + module = 
config_or_subclass_inserter(module) + return module + linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16) - ql = apply_quant(linear) + ql = _apply(linear, apply_quant) ql.to("cuda") linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16) - ql = apply_quant(linear) + ql = _apply(linear, apply_quant) ql.to(device="cuda") linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16) - ql = apply_quant(linear) + ql = _apply(linear, apply_quant) ql.cuda() @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") diff --git a/test/quantization/test_quant_api.py b/test/quantization/test_quant_api.py index acd9b50c5a..b9220c2815 100644 --- a/test/quantization/test_quant_api.py +++ b/test/quantization/test_quant_api.py @@ -30,6 +30,9 @@ Quantizer, TwoStepQuantizer, _replace_with_custom_fn_if_matches_filter, + float8_dynamic_activation_float8_weight, + float8_static_activation_float8_weight, + float8_weight_only, int4_weight_only, int8_dynamic_activation_int4_weight, int8_dynamic_activation_int8_weight, @@ -784,9 +787,21 @@ def test_int4wo_cpu(self, dtype, x_dim): assert "_weight_int4pack_mm_for_cpu" in code[0] assert "aten.mm.default" not in code[0] + # TODO(#1690): move to new config names @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+") @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") - def test_int4_weight_only_numerics(self): + @common_utils.parametrize( + "config", + [ + int4_weight_only(), + float8_weight_only(), + float8_dynamic_activation_float8_weight(), + float8_static_activation_float8_weight( + scale=torch.tensor([1.0], device="cuda") + ), + ], + ) + def test_workflow_e2e_numerics(self, config): """ Simple test of e2e int4_weight_only workflow, comparing numerics to a bfloat16 baseline. @@ -796,16 +811,16 @@ def test_int4_weight_only_numerics(self): # TODO(future): model in float32 leads to error: https://gist.github.com/vkuzo/63b3bcd7818393021a6e3fb4ccf3c469 # is that expected? m_ref = torch.nn.Sequential(torch.nn.Linear(128, 128)).cuda().bfloat16() - m_int4_wo = copy.deepcopy(m_ref) + m_q = copy.deepcopy(m_ref) # quantize - quantize_(m_int4_wo, int4_weight_only()) + quantize_(m_q, config) with torch.no_grad(): y_ref = m_ref(x) - y_int4_wo = m_int4_wo(x) + y_q = m_q(x) - sqnr = compute_error(y_ref, y_int4_wo) + sqnr = compute_error(y_ref, y_q) assert sqnr >= 20, f"SQNR {sqnr} is too low" diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index a6e8ee8e0b..01e3a7c029 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -1030,30 +1030,43 @@ def int8_dynamic_activation_int8_semi_sparse_weight(): return int8_dynamic_activation_int8_weight(layout=SemiSparseLayout()) -def float8_weight_only(weight_dtype: torch.dtype = torch.float8_e4m3fn): +@dataclass +class Float8WeightOnlyConfig(AOBaseConfig): """ - Applies float8 weight-only symmetric per-channel quantization to linear layers. + Configuration for applying float8 weight-only symmetric per-channel quantization to linear layers. Args: weight_dtype (torch.dtype): The target data type for weight quantization. Default is torch.float8_e4m3fn. Note: The actual matmul will be computed in original precision of the weight tensor. 
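# Illustrative sketch, not part of this patch: weight-only float8 quantization
# through the new config object; the lowercase `float8_weight_only` name
# remains as an alias. Assumes a CUDA device and bfloat16 weights.
import torch

from torchao.quantization import Float8WeightOnlyConfig, quantize_

model = torch.nn.Sequential(torch.nn.Linear(128, 128)).cuda().bfloat16()
quantize_(model, Float8WeightOnlyConfig())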
- """ - from torchao.dtypes import to_affine_quantized_floatx - def apply_float8wo_quant(weight): - block_size = (1, weight.shape[1]) - return to_affine_quantized_floatx( - input_float=weight, - block_size=block_size, - target_dtype=weight_dtype, - scale_dtype=None, - _layout=Float8Layout(mm_config=None), - ) + weight_dtype: torch.dtype = torch.float8_e4m3fn + + +# for BC +float8_weight_only = Float8WeightOnlyConfig + + +@register_quantize_module_handler(Float8WeightOnlyConfig) +def _float8_weight_only_transform( + module: torch.nn.Module, config: Float8WeightOnlyConfig +) -> torch.nn.Module: + from torchao.dtypes import to_affine_quantized_floatx - return _get_linear_subclass_inserter(apply_float8wo_quant) + weight = module.weight + block_size = (1, weight.shape[1]) + new_weight = to_affine_quantized_floatx( + input_float=weight, + block_size=block_size, + target_dtype=config.weight_dtype, + scale_dtype=None, + _layout=Float8Layout(mm_config=None), + ) + module.weight = torch.nn.Parameter(new_weight, requires_grad=False) + module.extra_repr = types.MethodType(_linear_extra_repr, module) + return module _fp8_granularities = Union[PerTensor, PerRow] @@ -1170,16 +1183,10 @@ def _fp8_mm_compat(weight: torch.Tensor) -> bool: return is_compatible -def float8_dynamic_activation_float8_weight( - activation_dtype: torch.dtype = torch.float8_e4m3fn, - weight_dtype: torch.dtype = torch.float8_e4m3fn, - granularity: Optional[ - Union[_fp8_granularities, Tuple[_fp8_granularities, _fp8_granularities]] - ] = None, - mm_config: Optional[Float8MMConfig] = None, -): +@dataclass +class Float8DynamicActivationFloat8WeightConfig(AOBaseConfig): """ - Applies float8 dynamic symmetric quantization to both activations and weights of linear layers. + Configuration for applying float8 dynamic symmetric quantization to both activations and weights of linear layers. Args: activation_dtype (torch.dtype): The target data type for activation quantization. Default is torch.float8_e4m3fn. @@ -1192,56 +1199,75 @@ def float8_dynamic_activation_float8_weight( mm_config (Float8MMConfig): Configuration for the matrix multiplication. Default uses fast accumulation. 
""" - assert ( - is_sm_at_least_89() or is_MI300() - ), "Float8 dynamic activation quantization is only supported on CUDA>=8.9 and MI300+" - if mm_config is None: - mm_config = Float8MMConfig(use_fast_accum=True) - activation_granularity, weight_granularity = _normalize_granularity(granularity) + activation_dtype: torch.dtype = torch.float8_e4m3fn + weight_dtype: torch.dtype = torch.float8_e4m3fn + granularity: Optional[ + Union[_fp8_granularities, Tuple[_fp8_granularities, _fp8_granularities]] + ] = None + mm_config: Optional[Float8MMConfig] = None - def apply_float8_dynamic_activation_quant(weight: torch.Tensor): - if not _fp8_mm_compat(weight): - return weight - if isinstance(weight_granularity, PerRow): - assert ( - weight.dtype == torch.bfloat16 - ), "PerRow quantization only works for bfloat16 precision input weight" + def __post_init__(self): + assert ( + is_sm_at_least_89() or is_MI300() + ), "Float8 dynamic activation quantization is only supported on CUDA>=8.9 and MI300+" + if self.mm_config is None: + self.mm_config = Float8MMConfig(use_fast_accum=True) - block_size = get_block_size(weight.shape, weight_granularity) - quantized_weight = to_affine_quantized_floatx( - input_float=weight, - block_size=block_size, - target_dtype=weight_dtype, - scale_dtype=torch.float32, - _layout=Float8Layout(mm_config=mm_config), - ) - input_quant_func = _input_activation_quant_func_fp8 - input_quant_kwargs = { - "activation_granularity": activation_granularity, - "activation_dtype": activation_dtype, - } +# for bc +float8_dynamic_activation_float8_weight = Float8DynamicActivationFloat8WeightConfig - quantized_weight = to_linear_activation_quantized( - quantized_weight, input_quant_func, quant_kwargs=input_quant_kwargs - ) - return quantized_weight - return _get_linear_subclass_inserter(apply_float8_dynamic_activation_quant) +@register_quantize_module_handler(Float8DynamicActivationFloat8WeightConfig) +def _float8_dynamic_activation_float8_weight_transform( + module: torch.nn.Module, config: Float8DynamicActivationFloat8WeightConfig +): + activation_dtype = config.activation_dtype + weight_dtype = config.weight_dtype + granularity = config.granularity + mm_config = config.mm_config + weight = module.weight + activation_granularity, weight_granularity = _normalize_granularity(granularity) -def float8_static_activation_float8_weight( - scale: torch.Tensor, - activation_dtype: torch.dtype = torch.float8_e4m3fn, - weight_dtype: torch.dtype = torch.float8_e4m3fn, - granularity: Optional[ - Union[_fp8_granularities, Tuple[_fp8_granularities, _fp8_granularities]] - ] = None, - mm_config: Optional[Float8MMConfig] = None, -): + if not _fp8_mm_compat(weight): + # TODO(future PR): this should really throw an exception instead of silently + # not doing what the user asked + return module + if isinstance(weight_granularity, PerRow): + assert ( + weight.dtype == torch.bfloat16 + ), "PerRow quantization only works for bfloat16 precision input weight" + + block_size = get_block_size(weight.shape, weight_granularity) + quantized_weight = to_affine_quantized_floatx( + input_float=weight, + block_size=block_size, + target_dtype=weight_dtype, + scale_dtype=torch.float32, + _layout=Float8Layout(mm_config=mm_config), + ) + + input_quant_func = _input_activation_quant_func_fp8 + input_quant_kwargs = { + "activation_granularity": activation_granularity, + "activation_dtype": activation_dtype, + } + + quantized_weight = to_linear_activation_quantized( + quantized_weight, input_quant_func, quant_kwargs=input_quant_kwargs + 
) + + module.weight = torch.nn.Parameter(quantized_weight, requires_grad=False) + module.extra_repr = types.MethodType(_linear_extra_repr, module) + return module + + +@dataclass +class Float8StaticActivationFloat8WeightConfig(AOBaseConfig): """ - Applies float8 static symmetric quantization to + Configuration for applying float8 static symmetric quantization to Args: scale (torch.Tensor): The scale tensor for activation quantization. @@ -1249,47 +1275,73 @@ def float8_static_activation_float8_weight( weight_dtype (torch.dtype): The target data type for weight quantization. Default is torch.float8_e4m mm_config (Float8MMConfig): Configuration for the matrix multiplication. Default uses fast accumulation. """ - assert ( - is_sm_at_least_89() or is_MI300() - ), "Float8 static activation quantization is only supported on CUDA 8.9 and above" - if mm_config is None: - mm_config = Float8MMConfig(use_fast_accum=True) + scale: torch.Tensor + activation_dtype: torch.dtype = torch.float8_e4m3fn + weight_dtype: torch.dtype = torch.float8_e4m3fn + granularity: Optional[ + Union[_fp8_granularities, Tuple[_fp8_granularities, _fp8_granularities]] + ] = None + mm_config: Optional[Float8MMConfig] = None + + def __post_init__(self): + assert ( + is_sm_at_least_89() or is_MI300() + ), "Float8 static activation quantization is only supported on CUDA 8.9 and above" + if self.mm_config is None: + self.mm_config = Float8MMConfig(use_fast_accum=True) + + +# for bc +float8_static_activation_float8_weight = Float8StaticActivationFloat8WeightConfig + + +@register_quantize_module_handler(Float8StaticActivationFloat8WeightConfig) +def _float8_static_activation_float8_weight_transform( + module: torch.nn.Module, config: Float8StaticActivationFloat8WeightConfig +): + scale = config.scale + activation_dtype = config.activation_dtype + weight_dtype = config.weight_dtype + granularity = config.granularity + mm_config = config.mm_config + + weight = module.weight activation_granularity, weight_granularity = _normalize_granularity(granularity) assert isinstance( activation_granularity, PerTensor ), "Static quantization only supports PerTensor granularity" - def apply_float8_static_activation_quant(weight: torch.Tensor): - if not _fp8_mm_compat(weight): - return weight - block_size = get_block_size(weight.shape, weight_granularity) - quantized_weight = to_affine_quantized_floatx( - input_float=weight, - block_size=block_size, - target_dtype=weight_dtype, - scale_dtype=torch.float32, - _layout=Float8Layout(mm_config=mm_config), - ) + if not _fp8_mm_compat(weight): + # TODO(future PR): this should really throw an exception instead of silently + # not doing what the user asked + return module + block_size = get_block_size(weight.shape, weight_granularity) + quantized_weight = to_affine_quantized_floatx( + input_float=weight, + block_size=block_size, + target_dtype=weight_dtype, + scale_dtype=torch.float32, + _layout=Float8Layout(mm_config=mm_config), + ) - input_quant_func = _input_activation_quant_func_fp8 - input_quant_kwargs = { - "activation_granularity": activation_granularity, - "activation_dtype": activation_dtype, - } - - quantized_weight = ( - to_weight_tensor_with_linear_activation_quantization_metadata( - quantized_weight, - input_quant_func, - scale=scale, - zero_point=None, - quant_kwargs=input_quant_kwargs, - ) - ) - return quantized_weight + input_quant_func = _input_activation_quant_func_fp8 + input_quant_kwargs = { + "activation_granularity": activation_granularity, + "activation_dtype": activation_dtype, + } - 
return _get_linear_subclass_inserter(apply_float8_static_activation_quant) + quantized_weight = to_weight_tensor_with_linear_activation_quantization_metadata( + quantized_weight, + input_quant_func, + scale=scale, + zero_point=None, + quant_kwargs=input_quant_kwargs, + ) + + module.weight = torch.nn.Parameter(quantized_weight, requires_grad=False) + module.extra_repr = types.MethodType(_linear_extra_repr, module) + return module def uintx_weight_only(dtype, group_size=64, pack_dim=-1, use_hqq=False): From 5702ea030a5163cfe53d2b1ff8cf0610b5cea5fc Mon Sep 17 00:00:00 2001 From: vasiliy Date: Mon, 10 Feb 2025 16:51:31 -0800 Subject: [PATCH 09/23] Update [ghstack-poisoned] --- test/quantization/test_quant_api.py | 13 +++++++++++++ torchao/quantization/quant_api.py | 14 ++++++++------ 2 files changed, 21 insertions(+), 6 deletions(-) diff --git a/test/quantization/test_quant_api.py b/test/quantization/test_quant_api.py index b9220c2815..61ea2c5558 100644 --- a/test/quantization/test_quant_api.py +++ b/test/quantization/test_quant_api.py @@ -49,6 +49,7 @@ TORCH_VERSION_AT_LEAST_2_4, TORCH_VERSION_AT_LEAST_2_5, TORCH_VERSION_AT_LEAST_2_6, + is_sm_at_least_89, unwrap_tensor_subclass, ) @@ -806,6 +807,18 @@ def test_workflow_e2e_numerics(self, config): Simple test of e2e int4_weight_only workflow, comparing numerics to a bfloat16 baseline. """ + if ( + isinstance( + config, + ( + float8_dynamic_activation_float8_weight, + float8_static_activation_float8_weight, + ), + ) + and not is_sm_at_least_89() + ): + return unittest.skip("requires CUDA capability 8.9 or greater") + # set up inputs x = torch.randn(128, 128, device="cuda", dtype=torch.bfloat16) # TODO(future): model in float32 leads to error: https://gist.github.com/vkuzo/63b3bcd7818393021a6e3fb4ccf3c469 diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index 01e3a7c029..12ac02096e 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -1208,9 +1208,6 @@ class Float8DynamicActivationFloat8WeightConfig(AOBaseConfig): mm_config: Optional[Float8MMConfig] = None def __post_init__(self): - assert ( - is_sm_at_least_89() or is_MI300() - ), "Float8 dynamic activation quantization is only supported on CUDA>=8.9 and MI300+" if self.mm_config is None: self.mm_config = Float8MMConfig(use_fast_accum=True) @@ -1223,6 +1220,10 @@ def __post_init__(self): def _float8_dynamic_activation_float8_weight_transform( module: torch.nn.Module, config: Float8DynamicActivationFloat8WeightConfig ): + assert ( + is_sm_at_least_89() or is_MI300() + ), "Float8 dynamic activation quantization is only supported on CUDA>=8.9 and MI300+" + activation_dtype = config.activation_dtype weight_dtype = config.weight_dtype granularity = config.granularity @@ -1285,9 +1286,6 @@ class Float8StaticActivationFloat8WeightConfig(AOBaseConfig): mm_config: Optional[Float8MMConfig] = None def __post_init__(self): - assert ( - is_sm_at_least_89() or is_MI300() - ), "Float8 static activation quantization is only supported on CUDA 8.9 and above" if self.mm_config is None: self.mm_config = Float8MMConfig(use_fast_accum=True) @@ -1300,6 +1298,10 @@ def __post_init__(self): def _float8_static_activation_float8_weight_transform( module: torch.nn.Module, config: Float8StaticActivationFloat8WeightConfig ): + assert ( + is_sm_at_least_89() or is_MI300() + ), "Float8 static activation quantization is only supported on CUDA 8.9 and above" + scale = config.scale activation_dtype = config.activation_dtype weight_dtype = config.weight_dtype 
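# Illustrative sketch, not part of this patch: the static-activation float8
# config takes a precomputed activation scale, mirroring the call used in
# test_workflow_e2e_numerics above; requires CUDA capability 8.9+ or MI300+.
import torch

from torchao.quantization import (
    float8_static_activation_float8_weight,
    quantize_,
)

model = torch.nn.Sequential(torch.nn.Linear(128, 128)).cuda().bfloat16()
quantize_(
    model,
    float8_static_activation_float8_weight(scale=torch.tensor([1.0], device="cuda")),
)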
From 0542402b263299ac8cc643899f85913a9c037f28 Mon Sep 17 00:00:00 2001 From: vasiliy Date: Mon, 10 Feb 2025 19:10:03 -0800 Subject: [PATCH 10/23] Update [ghstack-poisoned] --- torchao/quantization/__init__.py | 2 ++ torchao/quantization/qat/__init__.py | 4 ++++ 2 files changed, 6 insertions(+) diff --git a/torchao/quantization/__init__.py b/torchao/quantization/__init__.py index b68ab8a179..71e8de337a 100644 --- a/torchao/quantization/__init__.py +++ b/torchao/quantization/__init__.py @@ -46,6 +46,7 @@ AffineQuantizedObserverBase, ) from .quant_api import ( + Int4WeightOnlyConfig, float8_dynamic_activation_float8_weight, float8_static_activation_float8_weight, float8_weight_only, @@ -119,6 +120,7 @@ "fpx_weight_only", "gemlite_uintx_weight_only", "swap_conv2d_1x1_to_linear", + "Int4WeightOnlyConfig", # smooth quant - subject to change "get_scale", "SmoothFakeDynQuantMixin", diff --git a/torchao/quantization/qat/__init__.py b/torchao/quantization/qat/__init__.py index 15008e03ea..5dc3d8e008 100644 --- a/torchao/quantization/qat/__init__.py +++ b/torchao/quantization/qat/__init__.py @@ -1,6 +1,8 @@ from .api import ( ComposableQATQuantizer, FakeQuantizeConfig, + FromIntXQuantizationAwareTrainingConfig, + IntXQuantizationAwareTrainingConfig, from_intx_quantization_aware_training, intx_quantization_aware_training, ) @@ -20,4 +22,6 @@ "Int8DynActInt4WeightQATQuantizer", "intx_quantization_aware_training", "from_intx_quantization_aware_training", + "FromIntXQuantizationAwareTrainingConfig", + "IntXQuantizationAwareTrainingConfig", ] From 5f7589795f99d09f09a99778cb65867854c43e36 Mon Sep 17 00:00:00 2001 From: vasiliy Date: Mon, 10 Feb 2025 20:01:31 -0800 Subject: [PATCH 11/23] Update [ghstack-poisoned] --- test/quantization/test_quant_api.py | 11 ++ torchao/quantization/__init__.py | 8 + torchao/quantization/quant_api.py | 269 ++++++++++++++++------------ 3 files changed, 170 insertions(+), 118 deletions(-) diff --git a/test/quantization/test_quant_api.py b/test/quantization/test_quant_api.py index e0f6cb1ace..b81e928aa6 100644 --- a/test/quantization/test_quant_api.py +++ b/test/quantization/test_quant_api.py @@ -33,6 +33,7 @@ float8_dynamic_activation_float8_weight, float8_static_activation_float8_weight, float8_weight_only, + int4_dynamic_activation_int4_weight, int4_weight_only, int8_dynamic_activation_int4_weight, int8_dynamic_activation_int8_weight, @@ -50,6 +51,7 @@ TORCH_VERSION_AT_LEAST_2_5, TORCH_VERSION_AT_LEAST_2_6, is_sm_at_least_89, + is_sm_at_least_90, unwrap_tensor_subclass, ) @@ -798,6 +800,10 @@ def test_int4wo_cpu(self, dtype, x_dim): float8_weight_only(), float8_dynamic_activation_float8_weight(), float8_static_activation_float8_weight(scale=torch.tensor([1.0])), + int4_dynamic_activation_int4_weight(), + int8_dynamic_activation_int8_weight(), + int8_dynamic_activation_int4_weight(), + int8_weight_only(), ], ) def test_workflow_e2e_numerics(self, config): @@ -816,6 +822,11 @@ def test_workflow_e2e_numerics(self, config): and not is_sm_at_least_89() ): return unittest.skip("requires CUDA capability 8.9 or greater") + elif ( + isinstance(config, int4_dynamic_activation_int4_weight) + and is_sm_at_least_90() + ): + return unittest.skip("only supported on CUDA capability 8.9, not greater") # scale has to be moved to cuda here because the parametrization init # code happens before gating for cuda availability diff --git a/torchao/quantization/__init__.py b/torchao/quantization/__init__.py index ca9a4141fc..a1d8bda058 100644 --- a/torchao/quantization/__init__.py +++ 
b/torchao/quantization/__init__.py @@ -49,7 +49,11 @@ Float8DynamicActivationFloat8WeightConfig, Float8StaticActivationFloat8WeightConfig, Float8WeightOnlyConfig, + Int4DynamicActivationInt4WeightConfig, Int4WeightOnlyConfig, + Int8DynamicActivationInt4WeightConfig, + Int8DynamicActivationInt8WeightConfig, + Int8WeightOnlyConfig, float8_dynamic_activation_float8_weight, float8_static_activation_float8_weight, float8_weight_only, @@ -123,7 +127,11 @@ "fpx_weight_only", "gemlite_uintx_weight_only", "swap_conv2d_1x1_to_linear", + "Int4DynamicActivationInt4WeightConfig", + "Int8DynamicActivationInt4WeightConfig", + "Int8DynamicActivationInt8WeightConfig", "Int4WeightOnlyConfig", + "Int8WeightOnlyConfig", "Float8WeightOnlyConfig", "Float8DynamicActivationFloat8WeightConfig", "Float8StaticActivationFloat8WeightConfig", diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index 12ac02096e..6c4ac40dc7 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -43,6 +43,7 @@ to_affine_quantized_intx, to_marlinqqq_quantized_intx, ) +from torchao.dtypes.utils import Layout from torchao.float8.float8_linear import Float8Linear from torchao.float8.inference import Float8MMConfig from torchao.quantization.linear_activation_weight_observed_tensor import ( @@ -590,18 +591,45 @@ def _int8_symm_per_token_quant(x: torch.Tensor) -> torch.Tensor: ) -def apply_int8_dynamic_activation_int4_weight_quant( - weight, - group_size=32, - layout=PlainLayout(), - mapping_type=MappingType.SYMMETRIC, - act_mapping_type=MappingType.ASYMMETRIC, +@dataclass +class Int8DynamicActivationInt4WeightConfig(AOBaseConfig): + """Configuration for applying int8 dynamic per token asymmetric activation quantization and int4 per group weight symmetric quantization to linear + This is used to produce a model for executorch backend, but currently executorch did not + support lowering for the quantized model from this flow yet + + Args: + `group_size`: parameter for quantization, controls the granularity of quantization, smaller + size is more fine grained + `layout`: layout type for quantized weight tensor, only supports `MarlinQQQLayout()` and `CutlassInt4PackedLayout()` for now + `mapping_type`: quantization type for weight, controls the weight quantization is symmetric or asymmetric + `act_mapping_type`: quantization type for activation, controls the activation quantization is symmetric or asymmetric + """ + + group_size: int = 32 + layout: Layout = PlainLayout() + mapping_type: MappingType = MappingType.SYMMETRIC + act_mapping_type: MappingType = MappingType.ASYMMETRIC + + +# for BC +int8_dynamic_activation_int4_weight = Int8DynamicActivationInt4WeightConfig + + +@register_quantize_module_handler(Int8DynamicActivationInt4WeightConfig) +def _int8_dynamic_activation_int4_weight_transform( + module: torch.nn.Module, config: Int8DynamicActivationInt4WeightConfig ): - """This is defined here instead of local function to support serialization""" + group_size = config.group_size + layout = config.layout + mapping_type = config.mapping_type + act_mapping_type = config.act_mapping_type + + weight = module.weight + if group_size is None or group_size == -1: group_size = weight.shape[-1] if weight.shape[-1] % group_size != 0: - return weight + return module # weight settings block_size = (1, group_size) @@ -639,41 +667,39 @@ def apply_int8_dynamic_activation_int4_weight_quant( _layout=layout, ) weight = to_linear_activation_quantized(weight, input_quant_func) - return weight + 
module.weight = torch.nn.Parameter(weight, requires_grad=False) + module.extra_repr = types.MethodType(_linear_extra_repr, module) + return module -def int8_dynamic_activation_int4_weight( - group_size=32, - layout=PlainLayout(), - mapping_type=MappingType.SYMMETRIC, - act_mapping_type=MappingType.ASYMMETRIC, -): - """Applies int8 dynamic per token asymmetric activation quantization and int4 per group weight symmetric quantization to linear - This is used to produce a model for executorch backend, but currently executorch did not - support lowering for the quantized model from this flow yet +@dataclass +class Int4DynamicActivationInt4WeightConfig(AOBaseConfig): + """Applies int4 dynamic per token symmetric activation quantization and int4 per row weight symmetric quantization to linear Args: - `group_size`: parameter for quantization, controls the granularity of quantization, smaller - size is more fine grained `layout`: layout type for quantized weight tensor, only supports `MarlinQQQLayout()` and `CutlassInt4PackedLayout()` for now `mapping_type`: quantization type for weight, controls the weight quantization is symmetric or asymmetric `act_mapping_type`: quantization type for activation, controls the activation quantization is symmetric or asymmetric """ - return _get_linear_subclass_inserter( - apply_int8_dynamic_activation_int4_weight_quant, - group_size=group_size, - layout=layout, - mapping_type=mapping_type, - act_mapping_type=act_mapping_type, - ) + layout: Layout = CutlassInt4PackedLayout() + mapping_type: MappingType = MappingType.SYMMETRIC + act_mapping_type: MappingType = MappingType.SYMMETRIC + + +# for bc +int4_dynamic_activation_int4_weight = Int4DynamicActivationInt4WeightConfig + + +@register_quantize_module_handler(Int4DynamicActivationInt4WeightConfig) +def _int4_dynamic_activation_int4_weight_transform( + module: torch.nn.Module, config: Int4DynamicActivationInt4WeightConfig +) -> torch.nn.Module: + weight = module.weight + layout = config.layout + mapping_type = config.mapping_type + act_mapping_type = config.act_mapping_type -def apply_int4_dynamic_activation_int4_weight_quant( - weight: torch.Tensor, - layout=CutlassInt4PackedLayout(), - mapping_type=MappingType.SYMMETRIC, - act_mapping_type=MappingType.SYMMETRIC, -): if not isinstance(layout, CutlassInt4PackedLayout): raise NotImplementedError( f"Only CutlassInt4PackedLayout layout is supported. Received {layout}." 
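For readers tracking the migration in the hunks above, here is a minimal usage sketch of the new config-style API. It is illustrative only: it assumes the config classes, their backward-compatible function-name aliases, and the `compute_error` helper land exactly as shown in these patches, and it mirrors the shape/dtype setup of `test_workflow_e2e_numerics` (a CUDA device is required).

```python
import copy

import torch

from torchao.quantization import Int8DynamicActivationInt4WeightConfig, quantize_
from torchao.quantization.utils import compute_error

# baseline model and a copy that will be quantized
m_ref = torch.nn.Sequential(torch.nn.Linear(128, 128)).cuda().bfloat16()
m_q = copy.deepcopy(m_ref)

# new style: pass a config instance to quantize_
quantize_(m_q, Int8DynamicActivationInt4WeightConfig(group_size=32))

# sanity-check numerics against the bfloat16 baseline, as the e2e test does
with torch.no_grad():
    x = torch.randn(128, 128, device="cuda", dtype=torch.bfloat16)
    sqnr = compute_error(m_ref(x), m_q(x))
assert sqnr >= 16.5, f"SQNR {sqnr} is too low"
```

Because the old names are kept as aliases of the config classes ("# for BC" in the hunks above), existing call sites such as `int8_dynamic_activation_int4_weight(group_size=32)` keep working and simply return a config instance.
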
@@ -698,27 +724,9 @@ def apply_int4_dynamic_activation_int4_weight_quant( weight, _int4_symm_per_token_quant_cutlass, ) - return weight - - -def int4_dynamic_activation_int4_weight( - layout=CutlassInt4PackedLayout(), - mapping_type=MappingType.SYMMETRIC, - act_mapping_type=MappingType.SYMMETRIC, -): - """Applies int4 dynamic per token symmetric activation quantization and int4 per row weight symmetric quantization to linear - - Args: - `layout`: layout type for quantized weight tensor, only supports `MarlinQQQLayout()` and `CutlassInt4PackedLayout()` for now - `mapping_type`: quantization type for weight, controls the weight quantization is symmetric or asymmetric - `act_mapping_type`: quantization type for activation, controls the activation quantization is symmetric or asymmetric - """ - return _get_linear_subclass_inserter( - apply_int4_dynamic_activation_int4_weight_quant, - layout=layout, - mapping_type=mapping_type, - act_mapping_type=act_mapping_type, - ) + module.weight = torch.nn.Parameter(weight, requires_grad=False) + module.extra_repr = types.MethodType(_linear_extra_repr, module) + return module def gemlite_uintx_weight_only( @@ -857,29 +865,41 @@ def _int4_weight_only_transform( return module -def int8_weight_only(group_size=None): +class Int8WeightOnlyConfig(AOBaseConfig): """ - Applies int8 weight-only symmetric per-channel quantization to linear layers. + Configuration for applying int8 weight-only symmetric per-channel quantization to linear layers. """ - def apply_int8wo_quant(weight, group_size=None): - mapping_type = MappingType.SYMMETRIC - target_dtype = torch.int8 - eps = torch.finfo(torch.float32).eps - zero_point_dtype = torch.int64 - if group_size is None: - group_size = weight.shape[1] - block_size = (1, group_size) - return to_affine_quantized_intx( - weight, - mapping_type, - block_size, - target_dtype, - eps=eps, - zero_point_dtype=zero_point_dtype, - ) + group_size: Optional[int] = None + - return _get_linear_subclass_inserter(apply_int8wo_quant, group_size=group_size) +# for BC +int8_weight_only = Int8WeightOnlyConfig + + +@register_quantize_module_handler(Int8WeightOnlyConfig) +def _int8_weight_only_transform(module: torch.nn.Module, config: Int8WeightOnlyConfig): + group_size = config.group_size + weight = module.weight + + mapping_type = MappingType.SYMMETRIC + target_dtype = torch.int8 + eps = torch.finfo(torch.float32).eps + zero_point_dtype = torch.int64 + if group_size is None: + group_size = weight.shape[1] + block_size = (1, group_size) + new_weight = to_affine_quantized_intx( + weight, + mapping_type, + block_size, + target_dtype, + eps=eps, + zero_point_dtype=zero_point_dtype, + ) + module.weight = torch.nn.Parameter(new_weight, requires_grad=False) + module.extra_repr = types.MethodType(_linear_extra_repr, module) + return module def _int8_symm_per_token_reduced_range_quant(x: torch.Tensor) -> torch.Tensor: @@ -958,63 +978,76 @@ def _int4_symm_per_token_quant_cutlass(x: torch.Tensor) -> torch.Tensor: ) -def int8_dynamic_activation_int8_weight( - layout=PlainLayout(), - act_mapping_type=MappingType.SYMMETRIC, - weight_only_decode=False, -): +@dataclass +class Int8DynamicActivationInt8WeightConfig(AOBaseConfig): """ - Applies int8 dynamic symmetric per-token activation and int8 per-channel weight + Configuration for applying int8 dynamic symmetric per-token activation and int8 per-channel weight quantization to linear layers """ - def apply_int8_dynamic_activation_int8_weight_quant(weight): - in_features = weight.shape[1] - # int8 dynamic 
quantization only has benefit when in_feature > 16 - if in_features <= 16: - logger.info( - f"Skipping applying int8_dynamic_activation_int8_weight to weight of shape {weight.shape}" - f" because `in_feature` is <= 16: {in_features}" - ) - return weight + layout: Optional[Layout] = PlainLayout() + act_mapping_type: Optional[MappingType] = MappingType.SYMMETRIC + weight_only_decode: bool = False - # weight settings - mapping_type = MappingType.SYMMETRIC - weight_zero_point_domain = ZeroPointDomain.NONE - def get_weight_block_size(x): - return (1, x.shape[1]) +# for BC +int8_dynamic_activation_int8_weight = Int8DynamicActivationInt8WeightConfig - target_dtype = torch.int8 - eps = torch.finfo(torch.float32).eps - zero_point_dtype = torch.int64 - if weight_only_decode: - input_quant_func = _int8_symm_per_token_reduced_range_quant_noop_decode - else: - # input settings - if act_mapping_type == MappingType.SYMMETRIC: - input_quant_func = _int8_symm_per_token_reduced_range_quant - else: - input_quant_func = _int8_asymm_per_token_quant +@register_quantize_module_handler(Int8DynamicActivationInt8WeightConfig) +def _int8_dynamic_activation_int8_weight_transform( + module: torch.nn.Module, config: Int8DynamicActivationInt8WeightConfig +) -> torch.nn.Module: + layout = config.layout + act_mapping_type = config.act_mapping_type + weight_only_decode = config.weight_only_decode - block_size = get_weight_block_size(weight) - weight = to_affine_quantized_intx( - weight, - mapping_type, - block_size, - target_dtype, - eps=eps, - zero_point_dtype=zero_point_dtype, - _layout=layout, - zero_point_domain=weight_zero_point_domain, + weight = module.weight + + in_features = weight.shape[1] + # int8 dynamic quantization only has benefit when in_feature > 16 + if in_features <= 16: + logger.info( + f"Skipping applying int8_dynamic_activation_int8_weight to weight of shape {weight.shape}" + f" because `in_feature` is <= 16: {in_features}" ) - weight = to_linear_activation_quantized(weight, input_quant_func) - return weight + return module + + # weight settings + mapping_type = MappingType.SYMMETRIC + weight_zero_point_domain = ZeroPointDomain.NONE - return _get_linear_subclass_inserter( - apply_int8_dynamic_activation_int8_weight_quant + def get_weight_block_size(x): + return (1, x.shape[1]) + + target_dtype = torch.int8 + eps = torch.finfo(torch.float32).eps + zero_point_dtype = torch.int64 + + if weight_only_decode: + input_quant_func = _int8_symm_per_token_reduced_range_quant_noop_decode + else: + # input settings + if act_mapping_type == MappingType.SYMMETRIC: + input_quant_func = _int8_symm_per_token_reduced_range_quant + else: + input_quant_func = _int8_asymm_per_token_quant + + block_size = get_weight_block_size(weight) + weight = to_affine_quantized_intx( + weight, + mapping_type, + block_size, + target_dtype, + eps=eps, + zero_point_dtype=zero_point_dtype, + _layout=layout, + zero_point_domain=weight_zero_point_domain, ) + weight = to_linear_activation_quantized(weight, input_quant_func) + module.weight = torch.nn.Parameter(weight, requires_grad=False) + module.extra_repr = types.MethodType(_linear_extra_repr, module) + return module def int8_dynamic_activation_int8_semi_sparse_weight(): From 1c9c39faf7e69c130a9b89738c07fb2a35684d5f Mon Sep 17 00:00:00 2001 From: vasiliy Date: Mon, 10 Feb 2025 20:33:01 -0800 Subject: [PATCH 12/23] Update [ghstack-poisoned] --- test/dtypes/test_affine_quantized.py | 1 + test/quantization/test_quant_api.py | 23 +++- torchao/quantization/quant_api.py | 189 
+++++++++++++++++---------- 3 files changed, 145 insertions(+), 68 deletions(-) diff --git a/test/dtypes/test_affine_quantized.py b/test/dtypes/test_affine_quantized.py index d26f1d8e04..616701f1e3 100644 --- a/test/dtypes/test_affine_quantized.py +++ b/test/dtypes/test_affine_quantized.py @@ -218,6 +218,7 @@ def test_flatten_unflatten(self, device, dtype): linear = torch.nn.Linear(128, 256, dtype=dtype, device=device) if isinstance(apply_quant, AOBaseConfig): quantize_(linear, apply_quant) + ql = linear else: # TODO(#1690): delete this once config migration is done ql = apply_quant(linear) diff --git a/test/quantization/test_quant_api.py b/test/quantization/test_quant_api.py index b81e928aa6..13f3800891 100644 --- a/test/quantization/test_quant_api.py +++ b/test/quantization/test_quant_api.py @@ -33,11 +33,14 @@ float8_dynamic_activation_float8_weight, float8_static_activation_float8_weight, float8_weight_only, + fpx_weight_only, + gemlite_uintx_weight_only, int4_dynamic_activation_int4_weight, int4_weight_only, int8_dynamic_activation_int4_weight, int8_dynamic_activation_int8_weight, int8_weight_only, + uintx_weight_only, ) from torchao.quantization.quant_primitives import MappingType from torchao.quantization.subclass import ( @@ -55,6 +58,13 @@ unwrap_tensor_subclass, ) +try: + import gemlite # noqa: F401 + + has_gemlite = True +except ModuleNotFoundError: + has_gemlite = False + def dynamic_quant(model, example_inputs): m = torch.export.export(model, example_inputs, strict=True).module() @@ -804,6 +814,9 @@ def test_int4wo_cpu(self, dtype, x_dim): int8_dynamic_activation_int8_weight(), int8_dynamic_activation_int4_weight(), int8_weight_only(), + fpx_weight_only(ebits=4, mbits=3), + gemlite_uintx_weight_only(), + uintx_weight_only(dtype=torch.uint4), ], ) def test_workflow_e2e_numerics(self, config): @@ -827,17 +840,23 @@ def test_workflow_e2e_numerics(self, config): and is_sm_at_least_90() ): return unittest.skip("only supported on CUDA capability 8.9, not greater") + elif isinstance(config, gemlite_uintx_weight_only) and not has_gemlite: + return unittest.skip("gemlite not available") # scale has to be moved to cuda here because the parametrization init # code happens before gating for cuda availability if isinstance(config, float8_static_activation_float8_weight): config.scale = config.scale.to("cuda") + dtype = torch.bfloat16 + if isinstance(config, gemlite_uintx_weight_only): + dtype = torch.float16 + # set up inputs - x = torch.randn(128, 128, device="cuda", dtype=torch.bfloat16) + x = torch.randn(128, 128, device="cuda", dtype=dtype) # TODO(future): model in float32 leads to error: https://gist.github.com/vkuzo/63b3bcd7818393021a6e3fb4ccf3c469 # is that expected? - m_ref = torch.nn.Sequential(torch.nn.Linear(128, 128)).cuda().bfloat16() + m_ref = torch.nn.Sequential(torch.nn.Linear(128, 128)).cuda().to(dtype) m_q = copy.deepcopy(m_ref) # quantize diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index 6c4ac40dc7..7991383e16 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -729,12 +729,8 @@ def _int4_dynamic_activation_int4_weight_transform( return module -def gemlite_uintx_weight_only( - group_size: Optional[int] = 64, - bit_width: int = 4, - packing_bitwidth: int = 32, - contiguous: Optional[bool] = None, -): +@dataclass +class GemliteUIntXWeightOnlyConfig(AOBaseConfig): """ applies weight only 4 or 8 bit integer quantization and utilizes the gemlite triton kernel and its associated weight packing format. 
This only works for fp16 models. 8 bit quantization is symmetric, 4 bit quantization is asymmetric. @@ -747,16 +743,39 @@ def gemlite_uintx_weight_only( `contiguous`: if set, the weight will be packed as specified. Leaving it as None lets gemlite determine the best choice. """ + group_size: Optional[int] = 64 + bit_width: int = 4 + packing_bitwidth: int = 32 + contiguous: Optional[bool] = None + + +# for BC +gemlite_uintx_weight_only = GemliteUIntXWeightOnlyConfig + + +@register_quantize_module_handler(GemliteUIntXWeightOnlyConfig) +def _gemlite_uintx_weight_only_transform( + module: torch.nn.Module, config: GemliteUIntXWeightOnlyConfig +): + group_size = config.group_size + bit_width = config.bit_width + packing_bitwidth = config.packing_bitwidth + contiguous = config.contiguous + + weight = module.weight + from torchao.dtypes.uintx.gemlite_layout import get_gemlite_aqt_kwargs use_hqq = True if bit_width == 4 else False - apply_fn = lambda weight: to_affine_quantized_intx( + new_weight = to_affine_quantized_intx( weight, **get_gemlite_aqt_kwargs( weight, group_size, bit_width, packing_bitwidth, contiguous, use_hqq ), ) - return _get_linear_subclass_inserter(apply_fn) + module.weight = torch.nn.Parameter(new_weight, requires_grad=False) + module.extra_repr = types.MethodType(_linear_extra_repr, module) + return module @dataclass @@ -1379,9 +1398,10 @@ def _float8_static_activation_float8_weight_transform( return module -def uintx_weight_only(dtype, group_size=64, pack_dim=-1, use_hqq=False): +@dataclass +class UIntXWeightOnlyConfig(AOBaseConfig): """ - Applies uintx weight-only asymmetric per-group quantization to linear layers, using uintx quantization where + Configuration for applying uintx weight-only asymmetric per-group quantization to linear layers, using uintx quantization where x is the number of bits specified by `dtype` Args: @@ -1391,6 +1411,28 @@ def uintx_weight_only(dtype, group_size=64, pack_dim=-1, use_hqq=False): `pack_dim`: the dimension we use for packing, defaults to -1 `use_hqq`: whether to use hqq algorithm or the default algorithm to quantize the weight """ + + dtype: torch.dtype + group_size: int = 64 + pack_dim: int = -1 + use_hqq: bool = False + + +# for BC +uintx_weight_only = UIntXWeightOnlyConfig + + +@register_quantize_module_handler(UIntXWeightOnlyConfig) +def _uintx_weight_only_transform( + module: torch.nn.Module, config: UIntXWeightOnlyConfig +): + dtype = config.dtype + group_size = config.group_size + pack_dim = config.pack_dim + use_hqq = config.use_hqq + + weight = module.weight + from torchao.quantization.quant_primitives import _DTYPE_TO_QVALUE_BOUNDS SUPPORTED_DTYPES = { @@ -1405,49 +1447,50 @@ def uintx_weight_only(dtype, group_size=64, pack_dim=-1, use_hqq=False): } assert dtype in SUPPORTED_DTYPES, f"Unsupported dtype for hqq: {dtype}" - def apply_uintx_weight_only_quant(weight, dtype): - mapping_type = MappingType.ASYMMETRIC - block_size = (1, group_size) - - if use_hqq: - if dtype == torch.uint4: - logger.warn( - "Recommended to use `int4_weight_only(group_size, use_hqq=True)` for the best performance" - ) - quant_min, quant_max = _DTYPE_TO_QVALUE_BOUNDS[dtype] - dtype = torch.uint8 - eps = None - zero_point_dtype = None - zero_point_domain = ZeroPointDomain.FLOAT - preserve_zero = False - _layout = PlainLayout() - else: - quant_min, quant_max = None, None - eps = torch.finfo(torch.float32).eps - zero_point_dtype = torch.int32 - zero_point_domain = ZeroPointDomain.INT - preserve_zero = True - _layout = UintxLayout(dtype=dtype, 
pack_dim=pack_dim) + mapping_type = MappingType.ASYMMETRIC + block_size = (1, group_size) - return to_affine_quantized_intx( - weight, - mapping_type, - block_size, - dtype, - quant_min=quant_min, - quant_max=quant_max, - eps=eps, - zero_point_dtype=zero_point_dtype, - zero_point_domain=zero_point_domain, - preserve_zero=preserve_zero, - _layout=_layout, - use_hqq=use_hqq, - ) + if use_hqq: + if dtype == torch.uint4: + logger.warn( + "Recommended to use `int4_weight_only(group_size, use_hqq=True)` for the best performance" + ) + quant_min, quant_max = _DTYPE_TO_QVALUE_BOUNDS[dtype] + dtype = torch.uint8 + eps = None + zero_point_dtype = None + zero_point_domain = ZeroPointDomain.FLOAT + preserve_zero = False + _layout = PlainLayout() + else: + quant_min, quant_max = None, None + eps = torch.finfo(torch.float32).eps + zero_point_dtype = torch.int32 + zero_point_domain = ZeroPointDomain.INT + preserve_zero = True + _layout = UintxLayout(dtype=dtype, pack_dim=pack_dim) - return _get_linear_subclass_inserter(apply_uintx_weight_only_quant, dtype=dtype) + new_weight = to_affine_quantized_intx( + weight, + mapping_type, + block_size, + dtype, + quant_min=quant_min, + quant_max=quant_max, + eps=eps, + zero_point_dtype=zero_point_dtype, + zero_point_domain=zero_point_domain, + preserve_zero=preserve_zero, + _layout=_layout, + use_hqq=use_hqq, + ) + module.weight = torch.nn.Parameter(new_weight, requires_grad=False) + module.extra_repr = types.MethodType(_linear_extra_repr, module) + return module -def fpx_weight_only(ebits: int, mbits: int): +@dataclass +class FPXWeightOnlyConfig(AOBaseConfig): """Sub-byte floating point dtypes defined by `ebits`: exponent bits and `mbits`: mantissa bits e.g. fp6_e3_m2, fp6_e2_m3, ... The packing format and kernels are from the fp6-llm paper: https://arxiv.org/abs/2401.14112 @@ -1458,26 +1501,40 @@ def fpx_weight_only(ebits: int, mbits: int): in the future """ - def apply_quant_llm(weight: torch.Tensor) -> torch.Tensor: - from torchao.dtypes import to_affine_quantized_fpx - from torchao.dtypes.floatx import FloatxTensorCoreLayout + ebits: int + mbits: int - assert ( - weight.dim() == 2 - ), f"floatx only works for 2-d Tensor, got: {weight.dim()}" - out_dim, in_dim = weight.shape - if (in_dim % 64 != 0) or (out_dim % 256 != 0): - logger.info( - f"Skipping floatx quantization float{ebits + mbits + 1}_{ebits}_{mbits} because " - f"the shape is not compatible with the kernel: in_dim={in_dim}, out_dim={out_dim} " - "expected in_dim % 64 == 0 and out_dim % 256 == 0" - ) - return weight - _layout = FloatxTensorCoreLayout(ebits, mbits) - return to_affine_quantized_fpx(weight, _layout) +# for BC +fpx_weight_only = FPXWeightOnlyConfig + + +@register_quantize_module_handler(FPXWeightOnlyConfig) +def _fpx_weight_only_transform( + module: torch.nn.Module, config: FPXWeightOnlyConfig +) -> torch.nn.Module: + ebits = config.ebits + mbits = config.mbits + weight = module.weight + + from torchao.dtypes import to_affine_quantized_fpx + from torchao.dtypes.floatx import FloatxTensorCoreLayout - return _get_linear_subclass_inserter(apply_quant_llm) + assert weight.dim() == 2, f"floatx only works for 2-d Tensor, got: {weight.dim()}" + out_dim, in_dim = weight.shape + if (in_dim % 64 != 0) or (out_dim % 256 != 0): + logger.info( + f"Skipping floatx quantization float{ebits + mbits + 1}_{ebits}_{mbits} because " + f"the shape is not compatible with the kernel: in_dim={in_dim}, out_dim={out_dim} " + "expected in_dim % 64 == 0 and out_dim % 256 == 0" + ) + return module + + _layout = 
FloatxTensorCoreLayout(ebits, mbits) + new_weight = to_affine_quantized_fpx(weight, _layout) + module.weight = torch.nn.Parameter(new_weight, requires_grad=False) + module.extra_repr = types.MethodType(_linear_extra_repr, module) + return module if TORCH_VERSION_AT_LEAST_2_5: From 1ff1f6ee9326695f67699773ef5cf1ed006c0f9a Mon Sep 17 00:00:00 2001 From: vasiliy Date: Mon, 10 Feb 2025 20:37:45 -0800 Subject: [PATCH 13/23] Update [ghstack-poisoned] --- test/dtypes/test_affine_quantized.py | 1 + torchao/quantization/quant_api.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/test/dtypes/test_affine_quantized.py b/test/dtypes/test_affine_quantized.py index d26f1d8e04..616701f1e3 100644 --- a/test/dtypes/test_affine_quantized.py +++ b/test/dtypes/test_affine_quantized.py @@ -218,6 +218,7 @@ def test_flatten_unflatten(self, device, dtype): linear = torch.nn.Linear(128, 256, dtype=dtype, device=device) if isinstance(apply_quant, AOBaseConfig): quantize_(linear, apply_quant) + ql = linear else: # TODO(#1690): delete this once config migration is done ql = apply_quant(linear) diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index 6c4ac40dc7..7aec7a3a7c 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -865,11 +865,11 @@ def _int4_weight_only_transform( return module +@dataclass class Int8WeightOnlyConfig(AOBaseConfig): """ Configuration for applying int8 weight-only symmetric per-channel quantization to linear layers. """ - group_size: Optional[int] = None From c2ed2da25edd2d557440910ce932d28bfd085bd4 Mon Sep 17 00:00:00 2001 From: vasiliy Date: Mon, 10 Feb 2025 20:39:22 -0800 Subject: [PATCH 14/23] Update [ghstack-poisoned] --- torchao/quantization/quant_api.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index 7aec7a3a7c..b4f6c86252 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -870,6 +870,7 @@ class Int8WeightOnlyConfig(AOBaseConfig): """ Configuration for applying int8 weight-only symmetric per-channel quantization to linear layers. 
""" + group_size: Optional[int] = None From 397002e2fb320adeb7bd10c67925595fc5369ef9 Mon Sep 17 00:00:00 2001 From: vasiliy Date: Tue, 11 Feb 2025 06:00:42 -0800 Subject: [PATCH 15/23] Update [ghstack-poisoned] --- test/quantization/test_quant_api.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/quantization/test_quant_api.py b/test/quantization/test_quant_api.py index b81e928aa6..4cb0ee3579 100644 --- a/test/quantization/test_quant_api.py +++ b/test/quantization/test_quant_api.py @@ -848,7 +848,7 @@ def test_workflow_e2e_numerics(self, config): y_q = m_q(x) sqnr = compute_error(y_ref, y_q) - assert sqnr >= 20, f"SQNR {sqnr} is too low" + assert sqnr >= 16.5, f"SQNR {sqnr} is too low" class TestMultiTensorFlow(TestCase): From 5514a99ecb7a0d4f0e05f79803c2f06fc72c865a Mon Sep 17 00:00:00 2001 From: vasiliy Date: Tue, 11 Feb 2025 06:10:40 -0800 Subject: [PATCH 16/23] Update [ghstack-poisoned] --- test/dtypes/test_uintx.py | 6 +++--- test/hqq/test_hqq_affine.py | 8 +++----- 2 files changed, 6 insertions(+), 8 deletions(-) diff --git a/test/dtypes/test_uintx.py b/test/dtypes/test_uintx.py index da43253678..9bc983885e 100644 --- a/test/dtypes/test_uintx.py +++ b/test/dtypes/test_uintx.py @@ -150,7 +150,7 @@ def test_uintx_target_dtype(dtype): linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda") # make sure it runs - uintx_weight_only(dtype)(linear) + quantize_(linear, uintx_weight_only(dtype)) linear(torch.randn(1, 128, dtype=torch.bfloat16, device="cuda")) @@ -165,7 +165,7 @@ def test_uintx_target_dtype_compile(dtype): linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda") # make sure it runs - uintx_weight_only(dtype)(linear) + quantize_(linear, uintx_weight_only(dtype)) linear = torch.compile(linear) linear(torch.randn(1, 128, dtype=torch.bfloat16, device="cuda")) @@ -196,6 +196,6 @@ def test_uintx_model_size(dtype): ) bf16_size = get_model_size_in_bytes(linear) # make sure it runs - uintx_weight_only(dtype)(linear[0]) + quantize_(linear[0], uintx_weight_only(dtype)) quantized_size = get_model_size_in_bytes(linear) assert bf16_size * _dtype_to_ratio[dtype] == quantized_size diff --git a/test/hqq/test_hqq_affine.py b/test/hqq/test_hqq_affine.py index 096c9d26ba..d18ff59f99 100644 --- a/test/hqq/test_hqq_affine.py +++ b/test/hqq/test_hqq_affine.py @@ -53,12 +53,10 @@ def _eval_hqq(dtype): dummy_linear.weight.data = W if dtype == torch.uint4: config = int4_weight_only(group_size=max(block_size), use_hqq=True) - quantize_(dummy_linear, config) - q_tensor_hqq = dummy_linear.weight else: - q_tensor_hqq = uintx_weight_only( - dtype, group_size=max(block_size), use_hqq=True - )(dummy_linear).weight + config = uintx_weight_only(dtype, group_size=max(block_size), use_hqq=True) + quantize_(dummy_linear, config) + q_tensor_hqq = dummy_linear.weight quant_linear_layer = torch.nn.Linear( W.shape[1], W.shape[0], bias=False, device=W.device From 6684b39668b1dae71c797ba184a87e52b49c0feb Mon Sep 17 00:00:00 2001 From: vasiliy Date: Tue, 11 Feb 2025 14:18:07 -0800 Subject: [PATCH 17/23] Update [ghstack-poisoned] --- torchao/quantization/__init__.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchao/quantization/__init__.py b/torchao/quantization/__init__.py index 3ec653df37..7b270b128e 100644 --- a/torchao/quantization/__init__.py +++ b/torchao/quantization/__init__.py @@ -50,7 +50,7 @@ Float8StaticActivationFloat8WeightConfig, Float8WeightOnlyConfig, FPXWeightOnlyConfig, - GemliteUIntxWeightOnlyConfig, + 
GemliteUIntXWeightOnlyConfig, Int4DynamicActivationInt4WeightConfig, Int4WeightOnlyConfig, Int8DynamicActivationInt4WeightConfig, @@ -140,7 +140,7 @@ "Float8StaticActivationFloat8WeightConfig", "UIntxWeightOnlyConfig", "FPXWeightOnlyConfig", - "GemliteUIntxWeightOnlyConfig", + "GemliteUIntXWeightOnlyConfig", # smooth quant - subject to change "get_scale", "SmoothFakeDynQuantMixin", From 4dcb3495c448c164da9bf43239f15782bdbc29ed Mon Sep 17 00:00:00 2001 From: vasiliy Date: Wed, 12 Feb 2025 05:58:30 -0800 Subject: [PATCH 18/23] Update [ghstack-poisoned] --- torchao/quantization/__init__.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchao/quantization/__init__.py b/torchao/quantization/__init__.py index 7b270b128e..5f15a6bbbe 100644 --- a/torchao/quantization/__init__.py +++ b/torchao/quantization/__init__.py @@ -56,7 +56,7 @@ Int8DynamicActivationInt4WeightConfig, Int8DynamicActivationInt8WeightConfig, Int8WeightOnlyConfig, - UIntxWeightOnlyConfig, + UIntXWeightOnlyConfig, float8_dynamic_activation_float8_weight, float8_static_activation_float8_weight, float8_weight_only, @@ -138,7 +138,7 @@ "Float8WeightOnlyConfig", "Float8DynamicActivationFloat8WeightConfig", "Float8StaticActivationFloat8WeightConfig", - "UIntxWeightOnlyConfig", + "UIntXWeightOnlyConfig", "FPXWeightOnlyConfig", "GemliteUIntXWeightOnlyConfig", # smooth quant - subject to change From 3aaf5a0cdc12e960137491c3c0840888c51a8191 Mon Sep 17 00:00:00 2001 From: vasiliy Date: Wed, 12 Feb 2025 21:55:39 -0800 Subject: [PATCH 19/23] Update [ghstack-poisoned] --- torchao/dtypes/floatx/float8_layout.py | 1 + tutorials/calibration_flow/static_quant.py | 3 ++- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/torchao/dtypes/floatx/float8_layout.py b/torchao/dtypes/floatx/float8_layout.py index 5a7e1924b3..656ebb61ae 100644 --- a/torchao/dtypes/floatx/float8_layout.py +++ b/torchao/dtypes/floatx/float8_layout.py @@ -253,6 +253,7 @@ def _linear_fp8_act_fp8_weight_impl( ): """Implements matmul between FP8 input and FP8 weight with compute using _scaled_mm""" scaled_mm_config = weight_tensor._layout.mm_config + assert scaled_mm_config is not None out_shape = get_out_shape(input_tensor.shape, weight_tensor.shape) # Weight tensor preprocessing diff --git a/tutorials/calibration_flow/static_quant.py b/tutorials/calibration_flow/static_quant.py index 4b7dfe405f..fd24a71189 100644 --- a/tutorials/calibration_flow/static_quant.py +++ b/tutorials/calibration_flow/static_quant.py @@ -163,12 +163,13 @@ def __init__( weight, weight_scale, weight_zero_point, block_size, self.target_dtype ) elif self.target_dtype == torch.float8_e4m3fn: + mm_config = Float8MMConfig(use_fast_accum=True) self.qweight = to_affine_quantized_floatx_static( weight, weight_scale, block_size, target_dtype, - Float8Layout(mm_config=None), + Float8Layout(mm_config=mm_config), ) else: raise ValueError(f"Unsupported target dtype {self.target_dtype}") From 3fd4cfc17ded95c9d9a0561daa738dc5bf305dcf Mon Sep 17 00:00:00 2001 From: vasiliy Date: Wed, 12 Feb 2025 22:15:36 -0800 Subject: [PATCH 20/23] Update [ghstack-poisoned] --- tutorials/calibration_flow/awq_like.py | 114 ++++++++++-------- tutorials/calibration_flow/gptq_like.py | 66 ++++++----- tutorials/calibration_flow/static_quant.py | 131 ++++++++++++--------- 3 files changed, 178 insertions(+), 133 deletions(-) diff --git a/tutorials/calibration_flow/awq_like.py b/tutorials/calibration_flow/awq_like.py index 5742b9b328..c047b8531e 100644 --- a/tutorials/calibration_flow/awq_like.py +++ 
b/tutorials/calibration_flow/awq_like.py @@ -8,11 +8,13 @@ """ import copy +from dataclasses import dataclass import torch import torch.nn.functional as F from torch import Tensor +from torchao.core.config import AOBaseConfig from torchao.dtypes import ( Float8Layout, to_affine_quantized_floatx_static, @@ -33,6 +35,9 @@ from torchao.quantization.quant_primitives import ( MappingType, ) +from torchao.quantization.transform_module import ( + register_quantize_module_handler, +) from torchao.quantization.utils import compute_error @@ -83,61 +88,72 @@ def replacement_fn(m): _replace_with_custom_fn_if_matches_filter(model, replacement_fn, _is_linear) +@dataclass +class ApplyAWQConfig(AOBaseConfig): + target_dtype: torch.dtype + + # converting observed linear module to linear module with quantzied weights (and quantized activations) # with tensor subclasses -def apply_awq(target_dtype: torch.dtype): - # target_dtype = torch.uint8 - def _apply_awq_to_linear(observed_linear): - # weight quantization - weight_scale, weight_zero_point = observed_linear.weight_obs.calculate_qparams() - - def weight_quant_func(weight): - block_size = (1, weight.shape[1]) - if target_dtype == torch.uint8: - return to_affine_quantized_intx_static( - weight, weight_scale, weight_zero_point, block_size, target_dtype - ) - elif target_dtype == torch.float8_e4m3fn: - return to_affine_quantized_floatx_static( - weight, - weight_scale, - block_size, - target_dtype, - Float8Layout(mm_config=None), - ) - else: - raise ValueError(f"Unsupported target dtype {target_dtype}") - - linear = torch.nn.Linear( - observed_linear.in_features, - observed_linear.out_features, - False, - device=observed_linear.weight.device, - dtype=observed_linear.weight.dtype, - ) - linear.weight = observed_linear.weight - linear.bias = observed_linear.bias - # activation quantization - # pretend this to be the equalization scale, in reality the `act_obs` should - # be an observer that can caluclate equalization scale - equalization_scale, _ = observed_linear.act_obs.calculate_qparams() - equalization_scale = torch.ones_like(equalization_scale) - linear.weight = torch.nn.Parameter( - weight_quant_func(linear.weight * equalization_scale), requires_grad=False - ) +@register_quantize_module_handler(ApplyAWQConfig) +def _apply_awq_transform( + module: torch.nn.Module, + config: ApplyAWQConfig, +): + target_dtype = config.target_dtype + observed_linear = module - linear.weight = torch.nn.Parameter( - to_weight_tensor_with_linear_activation_scale_metadata( - linear.weight, equalization_scale - ), - requires_grad=False, - ) + # target_dtype = torch.uint8 + # weight quantization + weight_scale, weight_zero_point = observed_linear.weight_obs.calculate_qparams() + + def weight_quant_func(weight): + block_size = (1, weight.shape[1]) + if target_dtype == torch.uint8: + return to_affine_quantized_intx_static( + weight, weight_scale, weight_zero_point, block_size, target_dtype + ) + elif target_dtype == torch.float8_e4m3fn: + return to_affine_quantized_floatx_static( + weight, + weight_scale, + block_size, + target_dtype, + Float8Layout(mm_config=None), + ) + else: + raise ValueError(f"Unsupported target dtype {target_dtype}") + + linear = torch.nn.Linear( + observed_linear.in_features, + observed_linear.out_features, + False, + device=observed_linear.weight.device, + dtype=observed_linear.weight.dtype, + ) + linear.weight = observed_linear.weight + linear.bias = observed_linear.bias + + # activation quantization + # pretend this to be the equalization scale, in reality 
the `act_obs` should + # be an observer that can caluclate equalization scale + equalization_scale, _ = observed_linear.act_obs.calculate_qparams() + equalization_scale = torch.ones_like(equalization_scale) - return linear + linear.weight = torch.nn.Parameter( + weight_quant_func(linear.weight * equalization_scale), requires_grad=False + ) + + linear.weight = torch.nn.Parameter( + to_weight_tensor_with_linear_activation_scale_metadata( + linear.weight, equalization_scale + ), + requires_grad=False, + ) - return _apply_awq_to_linear + return linear ######## Test ########## @@ -201,7 +217,7 @@ def test_awq(target_dtype: torch.dtype, mapping_type: MappingType): # quantized linear represented as an nn.Linear with modified tensor subclass weights # for both activation and weight quantization - quantize_(m, apply_awq(target_dtype), is_observed_linear) + quantize_(m, ApplyAWQConfig(target_dtype), is_observed_linear) print("quantized model (applying tensor subclass to weight):", m) after_quant = m(*example_inputs) assert compute_error(before_quant, after_quant) > 25 diff --git a/tutorials/calibration_flow/gptq_like.py b/tutorials/calibration_flow/gptq_like.py index 93c7e3c4ab..e4f28faf6f 100644 --- a/tutorials/calibration_flow/gptq_like.py +++ b/tutorials/calibration_flow/gptq_like.py @@ -33,6 +33,7 @@ import torch from torch.utils._pytree import tree_flatten, tree_unflatten +from torchao.core.config import AOBaseConfig from torchao.dtypes import ( to_affine_quantized_intx, to_affine_quantized_intx_static, @@ -47,6 +48,9 @@ to_linear_activation_quantized, ) from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter +from torchao.quantization.transform_module import ( + register_quantize_module_handler, +) from torchao.quantization.utils import compute_error torch.manual_seed(0) @@ -252,36 +256,42 @@ def _register_forward_pre_hook(module: torch.nn.Module): ) -# using a function to align with the API in quant_api -def apply_activation_static_weight_quant(): - def _apply_activation_static_weight_quant(observed_linear): - target_dtype = torch.uint8 - - # we can quantize the weight here as well +class ApplyActivationStaticWeightQuantConfig(AOBaseConfig): + pass - # activation quantization - act_scale, act_zero_point = ( - observed_linear.input_scale, - observed_linear.input_zp, - ) - input_quant_func = lambda x: to_affine_quantized_intx_static( - x, act_scale, act_zero_point, x.shape, target_dtype - ) - # for demo purpose only, we quantize the weight here - weight = observed_linear.weight - weight = to_affine_quantized_intx( - weight, MappingType.SYMMETRIC, (1, weight.shape[-1]), torch.int8 - ) - observed_linear.weight = torch.nn.Parameter( - to_linear_activation_quantized(weight, input_quant_func), - requires_grad=False, - ) - del observed_linear.input_scale - del observed_linear.input_zp - return observed_linear +# using a function to align with the API in quant_api +@register_quantize_module_handler(ApplyActivationStaticWeightQuantConfig) +def _apply_activation_static_weight_quant_transform( + module: torch.nn.Module, + config: ApplyActivationStaticWeightQuantConfig, +): + observed_linear = module + target_dtype = torch.uint8 + + # we can quantize the weight here as well + + # activation quantization + act_scale, act_zero_point = ( + observed_linear.input_scale, + observed_linear.input_zp, + ) + input_quant_func = lambda x: to_affine_quantized_intx_static( + x, act_scale, act_zero_point, x.shape, target_dtype + ) + # for demo purpose only, we quantize the weight here + 
weight = observed_linear.weight + weight = to_affine_quantized_intx( + weight, MappingType.SYMMETRIC, (1, weight.shape[-1]), torch.int8 + ) + observed_linear.weight = torch.nn.Parameter( + to_linear_activation_quantized(weight, input_quant_func), + requires_grad=False, + ) - return _apply_activation_static_weight_quant + del observed_linear.input_scale + del observed_linear.input_zp + return observed_linear example_inputs = (torch.randn(32, 64),) @@ -298,7 +308,7 @@ def _apply_activation_static_weight_quant(observed_linear): # just quantizing activation since we only observed quantization, this could be extended to support # quantizing weight as well -quantize_(m, apply_activation_static_weight_quant(), _is_linear) +quantize_(m, ApplyActivationStaticWeightQuantConfig(), _is_linear) for l in m.modules(): if isinstance(l, torch.nn.Linear): assert isinstance(l.weight, LinearActivationQuantizedTensor) diff --git a/tutorials/calibration_flow/static_quant.py b/tutorials/calibration_flow/static_quant.py index fd24a71189..1ebce411d3 100644 --- a/tutorials/calibration_flow/static_quant.py +++ b/tutorials/calibration_flow/static_quant.py @@ -3,11 +3,13 @@ """ import copy +from dataclasses import dataclass import torch import torch.nn.functional as F from torch import Tensor +from torchao.core.config import AOBaseConfig from torchao.dtypes import ( Float8Layout, to_affine_quantized_floatx_static, @@ -26,6 +28,9 @@ from torchao.quantization.quant_primitives import ( MappingType, ) +from torchao.quantization.transform_module import ( + register_quantize_module_handler, +) from torchao.quantization.utils import compute_error from torchao.utils import is_sm_at_least_90 @@ -77,66 +82,74 @@ def replacement_fn(m): _replace_with_custom_fn_if_matches_filter(model, replacement_fn, _is_linear) -# converting observed linear module to linear module with quantzied weights (and quantized activations) -# with tensor subclasses -def apply_static_quant(target_dtype: torch.dtype): - # target_dtype = torch.uint8 - def _apply_static_quant_to_linear(observed_linear): - # weight quantization - weight_scale, weight_zero_point = observed_linear.weight_obs.calculate_qparams() - - def weight_quant_func(weight): - block_size = (1, weight.shape[1]) - if target_dtype == torch.uint8: - return to_affine_quantized_intx_static( - weight, weight_scale, weight_zero_point, block_size, target_dtype - ) - elif target_dtype == torch.float8_e4m3fn: - mm_config = Float8MMConfig(use_fast_accum=True) - return to_affine_quantized_floatx_static( - weight, - weight_scale, - block_size, - target_dtype, - Float8Layout(mm_config=mm_config), - ) - else: - raise ValueError(f"Unsupported target dtype {target_dtype}") - - linear = torch.nn.Linear( - observed_linear.in_features, - observed_linear.out_features, - False, - device=observed_linear.weight.device, - dtype=observed_linear.weight.dtype, - ) - linear.weight = observed_linear.weight - linear.bias = observed_linear.bias +@dataclass +class ApplyStaticQuantConfig(AOBaseConfig): + target_dtype: torch.dtype - linear.weight = torch.nn.Parameter( - weight_quant_func(linear.weight), requires_grad=False - ) - # activation quantization - act_scale, act_zero_point = observed_linear.act_obs.calculate_qparams() +# converting observed linear module to linear module with quantzied weights (and quantized activations) +# with tensor subclasses +@register_quantize_module_handler(ApplyStaticQuantConfig) +def _apply_static_quant_transform( + module: torch.nn.Module, + config: ApplyStaticQuantConfig, +): + 
target_dtype = config.target_dtype + observed_linear = module + + # weight quantization + weight_scale, weight_zero_point = observed_linear.weight_obs.calculate_qparams() + + def weight_quant_func(weight): + block_size = (1, weight.shape[1]) if target_dtype == torch.uint8: - input_quant_func = lambda x: to_affine_quantized_intx_static( - x, act_scale, act_zero_point, x.shape, target_dtype + return to_affine_quantized_intx_static( + weight, weight_scale, weight_zero_point, block_size, target_dtype ) elif target_dtype == torch.float8_e4m3fn: - input_quant_func = lambda x: to_affine_quantized_floatx_static( - x, act_scale, x.shape, target_dtype, Float8Layout(mm_config=None) + mm_config = Float8MMConfig(use_fast_accum=True) + return to_affine_quantized_floatx_static( + weight, + weight_scale, + block_size, + target_dtype, + Float8Layout(mm_config=mm_config), ) else: raise ValueError(f"Unsupported target dtype {target_dtype}") - linear.weight = torch.nn.Parameter( - to_linear_activation_quantized(linear.weight, input_quant_func), - requires_grad=False, - ) - return linear + linear = torch.nn.Linear( + observed_linear.in_features, + observed_linear.out_features, + False, + device=observed_linear.weight.device, + dtype=observed_linear.weight.dtype, + ) + linear.weight = observed_linear.weight + linear.bias = observed_linear.bias - return _apply_static_quant_to_linear + linear.weight = torch.nn.Parameter( + weight_quant_func(linear.weight), requires_grad=False + ) + + # activation quantization + act_scale, act_zero_point = observed_linear.act_obs.calculate_qparams() + if target_dtype == torch.uint8: + input_quant_func = lambda x: to_affine_quantized_intx_static( + x, act_scale, act_zero_point, x.shape, target_dtype + ) + elif target_dtype == torch.float8_e4m3fn: + input_quant_func = lambda x: to_affine_quantized_floatx_static( + x, act_scale, x.shape, target_dtype, Float8Layout(mm_config=None) + ) + else: + raise ValueError(f"Unsupported target dtype {target_dtype}") + linear.weight = torch.nn.Parameter( + to_linear_activation_quantized(linear.weight, input_quant_func), + requires_grad=False, + ) + + return linear # alternative for converting observed linear module to quantized linear module @@ -210,11 +223,17 @@ def from_observed(cls, observed_linear, target_dtype): return quantized_linear -def apply_static_quant2(target_dtype: torch.dtype): - def _apply_static_quant2(observed_linear): - return QuantizedLinear.from_observed(observed_linear, target_dtype) +@dataclass +class ApplyStaticQuantConfig2(AOBaseConfig): + target_dtype: torch.dtype + - return _apply_static_quant2 +@register_quantize_module_handler(ApplyStaticQuantConfig2) +def apply_static_quant( + module: torch.nn.Module, + config: ApplyStaticQuantConfig2, +): + return QuantizedLinear.from_observed(module, config.target_dtype) class ToyLinearModel(torch.nn.Module): @@ -281,14 +300,14 @@ def test_static_quant(target_dtype: torch.dtype, mapping_type: MappingType): # quantized linear represented as an nn.Linear with modified tensor subclass weights # for both activation and weight quantization - quantize_(m, apply_static_quant(target_dtype), is_observed_linear) + quantize_(m, ApplyStaticQuantConfig(target_dtype), is_observed_linear) print("quantized model (applying tensor subclass to weight):", m) after_quant = m(*example_inputs) assert compute_error(before_quant, after_quant) > 25 print("test passed") # quantized linear as a standalone module - quantize_(m2, apply_static_quant2(target_dtype), is_observed_linear) + quantize_(m2, 
ApplyStaticQuantConfig2(target_dtype), is_observed_linear) print("quantized model (quantized module):", m2) after_quant = m2(*example_inputs) assert compute_error(before_quant, after_quant) > 25 From e0124f7ec170e05e43649e5eb9998f95dfc374d0 Mon Sep 17 00:00:00 2001 From: vasiliy Date: Wed, 12 Feb 2025 22:38:04 -0800 Subject: [PATCH 21/23] Update [ghstack-poisoned] --- README.md | 12 ++++---- torchao/quantization/README.md | 44 +++++++++++++++--------------- torchao/quantization/qat/README.md | 18 ++++++------ 3 files changed, 37 insertions(+), 37 deletions(-) diff --git a/README.md b/README.md index 0da273f91c..a35aca2e88 100644 --- a/README.md +++ b/README.md @@ -29,16 +29,16 @@ For inference, we have the option of ```python from torchao.quantization.quant_api import ( quantize_, - int8_dynamic_activation_int8_weight, - int4_weight_only, - int8_weight_only + Int8DynamicActivationInt8WeightConfig, + Int4WeightOnlyConfig, + Int8WeightOnlyConfig ) -quantize_(m, int4_weight_only()) +quantize_(m, Int4WeightOnlyConfig()) ``` -For gpt-fast `int4_weight_only()` is the best option at bs=1 as it **2x the tok/s and reduces the VRAM requirements by about 65%** over a torch.compiled baseline. +For gpt-fast `Int4WeightOnlyConfig()` is the best option at bs=1 as it **2x the tok/s and reduces the VRAM requirements by about 65%** over a torch.compiled baseline. -If you don't have enough VRAM to quantize your entire model on GPU and you find CPU quantization to be too slow then you can use the device argument like so `quantize_(model, int8_weight_only(), device="cuda")` which will send and quantize each layer individually to your GPU. +If you don't have enough VRAM to quantize your entire model on GPU and you find CPU quantization to be too slow then you can use the device argument like so `quantize_(model, Int8WeightOnlyConfig(), device="cuda")` which will send and quantize each layer individually to your GPU. If you see slowdowns with any of these techniques or you're unsure which option to use, consider using [autoquant](./torchao/quantization/README.md#autoquantization) which will automatically profile layers and pick the best way to quantize each layer. diff --git a/torchao/quantization/README.md b/torchao/quantization/README.md index ace4d8c14c..655a942718 100644 --- a/torchao/quantization/README.md +++ b/torchao/quantization/README.md @@ -82,7 +82,7 @@ model(input) When used as in the example above, when the `autoquant` api is called alongside torch.compile, autoquant sets up the model so that when its run on the next input, the autoquantization and torch.compile processes leave you with a heavily optimized model. -When `model(input)` is called, (under the hood) the tool does a preliminary run with the input where each linear layer keeps track of the different shapes and types of activations that it sees. Once the preliminary run is complete, the next step is to check each linear layer and benchmark the tracked shapes for different types of quantization techniques in order to pick the fastest one, attempting to take into account fusions where possible. Finally once the best class is found for each layer, the next step is to apply the necessary quantization technique to each layer, before finally allowing the normal `torch.compile` process to occur on the now quantized model. By default the api only uses int8 techniques, i.e. 
it chooses between no quantization, int8 dynamic quantization and int8 weight only quantization for each layer, though there is also an option add int4 quantization which can be used for maximum performance or to avoid perf regressions from `int4_weight_only()` since for certain (compute bound) regimes, int4 weight only quantization can be very slow. +When `model(input)` is called, (under the hood) the tool does a preliminary run with the input where each linear layer keeps track of the different shapes and types of activations that it sees. Once the preliminary run is complete, the next step is to check each linear layer and benchmark the tracked shapes for different types of quantization techniques in order to pick the fastest one, attempting to take into account fusions where possible. Finally once the best class is found for each layer, the next step is to apply the necessary quantization technique to each layer, before finally allowing the normal `torch.compile` process to occur on the now quantized model. By default the api only uses int8 techniques, i.e. it chooses between no quantization, int8 dynamic quantization and int8 weight only quantization for each layer, though there is also an option add int4 quantization which can be used for maximum performance or to avoid perf regressions from `Int4WeightOnlyConfig()` since for certain (compute bound) regimes, int4 weight only quantization can be very slow. Sometimes it is desirable to reuse a quantization plan that `autoquant` came up with. `torchao.quantization.AUTOQUANT_CACHE` is a dictionary holding autoquant's benchmark results. We can save it and restore it later, which will cause `autoquant` to choose the same quantization methods. @@ -109,13 +109,13 @@ be applied individually. While there are a large variety of quantization apis, t ```python # for torch 2.4+ -from torchao.quantization import quantize_, int4_weight_only +from torchao.quantization import quantize_, Int4WeightOnlyConfig group_size = 32 # you can enable [hqq](https://github.com/mobiusml/hqq/tree/master) quantization which is expected to improves accuracy through -# use_hqq flag for `int4_weight_only` quantization +# use_hqq flag for `Int4WeightOnlyConfig` quantization use_hqq = False -quantize_(model, int4_weight_only(group_size=group_size, use_hqq=use_hqq)) +quantize_(model, Int4WeightOnlyConfig(group_size=group_size, use_hqq=use_hqq)) # for torch 2.2.2 and 2.3 from torchao.quantization.quant_api import change_linear_weights_to_int4_woqtensors @@ -128,8 +128,8 @@ Note: The quantization error incurred by applying int4 quantization to your mode ```python # for torch 2.4+ -from torchao.quantization import quantize_, int8_weight_only -quantize_(model, int8_weight_only()) +from torchao.quantization import quantize_, Int8WeightOnlyConfig +quantize_(model, Int8WeightOnlyConfig()) # for torch 2.2.2 and 2.3 from torchao.quantization.quant_api import change_linear_weights_to_int8_woqtensors @@ -140,8 +140,8 @@ change_linear_weights_to_int8_woqtensors(model) ```python # for torch 2.4+ -from torchao.quantization import quantize_, int8_dynamic_activation_int8_weight -quantize_(model, int8_dynamic_activation_int8_weight()) +from torchao.quantization import quantize_, Int8DynamicActivationInt8WeightConfig +quantize_(model, Int8DynamicActivationInt8WeightConfig()) # for torch 2.2.2 and 2.3 from torchao.quantization.quant_api import change_linear_weights_to_int8_dqtensors @@ -152,8 +152,8 @@ change_linear_weights_to_int8_dqtensors(model) ```python # for torch 2.5+ -from 
torchao.quantization import quantize_, float8_weight_only -quantize_(model, float8_weight_only()) +from torchao.quantization import quantize_, Float8WeightOnlyConfig +quantize_(model, Float8WeightOnlyConfig()) ``` Supports all dtypes for original weight and activation. This API is only tested on H100. Hardware with CUDA compute capability 8.9 or greater is required. @@ -162,8 +162,8 @@ Supports all dtypes for original weight and activation. This API is only tested ```python # for torch 2.4+ -from torchao.quantization import quantize_, float8_dynamic_activation_float8_weight, PerTensor -quantize_(model, float8_dynamic_activation_float8_weight(granularity=PerTensor())) +from torchao.quantization import quantize_, Float8DynamicActivationFloat8WeightConfig, PerTensor +quantize_(model, Float8DynamicActivationFloat8WeightConfig(granularity=PerTensor())) ``` Supports all dtypes for original weight and activation. This API is only tested on H100. Hardware with CUDA compute capability 8.9 or greater is required. @@ -172,8 +172,8 @@ Supports all dtypes for original weight and activation. This API is only tested ```python # for torch 2.5+ -from torchao.quantization import quantize_, PerRow, float8_dynamic_activation_float8_weight -quantize_(model, float8_dynamic_activation_float8_weight(granularity=PerRow())) +from torchao.quantization import quantize_, PerRow, Float8DynamicActivationFloat8WeightConfig +quantize_(model, Float8DynamicActivationFloat8WeightConfig(granularity=PerRow())) ``` Per-row scaling is only supported for bfloat16 weight and activation. This API is only tested on H100. Hardware with CUDA compute capability 8.9 or greater is required. @@ -182,14 +182,14 @@ Per-row scaling is only supported for bfloat16 weight and activation. This API i ```python # for torch 2.4+ -from torchao.quantization import quantize_, fpx_weight_only -quantize_(model, fpx_weight_only(3, 2)) +from torchao.quantization import quantize_, FPXWeightOnlyConfig +quantize_(model, FPXWeightOnlyConfig(3, 2)) ``` You can find more information [here](../dtypes/floatx/README.md). It should be noted where most other TorchAO apis and benchmarks have focused on applying techniques on top of a bf16 model, performance, fp6 works primarily with the fp16 dtype. ## Affine Quantization Details -Affine quantization refers to the type of quantization that maps from high precision floating point numbers to quantized numbers (low precision integer or floating point dtypes) with an affine transformation, i.e.: `quantized_val = high_preicsion_float_val / scale + zero_point` where `scale` and `zero_point` are quantization parameters for some granularity and based on some data (also some dtypes may not require a `zero_point`). Each of the techniques in the above section qualify as Affine Quantization. +Affine quantization refers to the type of quantization that maps from high precision floating point numbers to quantized numbers (low precision integer or floating point dtypes) with an affine transformation, i.e.: `quantized_val = high_precision_float_val / scale + zero_point` where `scale` and `zero_point` are quantization parameters for some granularity and based on some data (also some dtypes may not require a `zero_point`). Each of the techniques in the above section qualify as Affine Quantization. ### Quantization Primitives We used to have different quantize and dequantize operators for quantization with different granularities. 
But in the end these can all be expressed with a `block_size` argument with different settings, so we unified existing quant primitives to `choose_qparams_affine`, `quantize_affine` and `dequantize_affine` that can represent symmetric/asymmetric per tensor/channel/token/channel_group quantization, this can be used to implement the unified quantized tensor subclass. @@ -200,7 +200,7 @@ Note: these primitive ops supports two "types" of quantization, distinguished by We also have a unified quantized tensor subclass that implements how to get a quantized tensor from floating point tensor and what does it mean to call linear ops on an instance of the tensor, e.g. `F.linear` and `aten.addmm`, with this we could dispatch to different operators (e.g. `int4mm` op) based on device (cpu, cuda) and quantization settings (`int4`, `int8`) and also packing formats (e.g. format optimized for cpu int4 mm kernel) #### Layouts -We extended the `layout` concept to represent different packing formats for a tensor. `AffineQuantizedTensor` supports `plain` and `tensor_core_tiled` layout. `plain` layout is used for `int8_weight_only` and `int8_dynamic_activation_int8_weight` and also as a default layout. `tensor_core_tiled` layout is used for `int4_weight_only` quantization and is packing the weights in a format that is compatible with tinygemm [int4mm](https://github.com/pytorch/pytorch/blob/39357ba06f48cda7d293a4995aa5eba2a46598b5/aten/src/ATen/native/native_functions.yaml#L4138) kernels. +We extended the `layout` concept to represent different packing formats for a tensor. `AffineQuantizedTensor` supports `plain` and `tensor_core_tiled` layout. `plain` layout is used for workflows backing `Int8WeightOnlyConfig` and `Int8DynamicActivationInt8WeightConfig` and also as a default layout. `tensor_core_tiled` layout is used for workflows backing `Int4WeightOnlyConfig` quantization and is packing the weights in a format that is compatible with tinygemm [int4mm](https://github.com/pytorch/pytorch/blob/39357ba06f48cda7d293a4995aa5eba2a46598b5/aten/src/ATen/native/native_functions.yaml#L4138) kernels. ### Zero Point Domains ```ZeroPointDomain``` is used to control the data types of zero points. ```ZeroPointDomain.None``` means zero_point is None, ```ZeroPointDomain.FLOAT``` means zero_point is in the floating point domain and ```ZeroPointDomain.INT``` means integer domain. For detailed implementation of different zero point data types, refer to [the reference implementation](../../test/quantization/test_quant_primitives.py). 
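To make the affine mapping described in this section concrete, the sketch below implements symmetric per-row int8 quantization directly from `quantized_val = high_precision_float_val / scale + zero_point` with `zero_point = 0`. It is plain PyTorch for illustration only, not the `choose_qparams_affine` / `quantize_affine` / `dequantize_affine` primitives themselves.

```python
import torch


def quantize_per_row_symmetric_int8(w: torch.Tensor):
    # one scale per output row, i.e. block_size == (1, w.shape[1]); zero_point is 0
    max_abs = w.abs().amax(dim=1, keepdim=True)
    scale = (max_abs / 127.0).clamp(min=torch.finfo(torch.float32).eps)
    q = torch.clamp(torch.round(w / scale), -127, 127).to(torch.int8)
    return q, scale


def dequantize_per_row_symmetric_int8(q: torch.Tensor, scale: torch.Tensor):
    return q.to(torch.float32) * scale


w = torch.randn(4, 8)
q, scale = quantize_per_row_symmetric_int8(w)
w_hat = dequantize_per_row_symmetric_int8(q, scale)
# round-trip error is bounded by half a quantization step per row
assert (w - w_hat).abs().max() <= scale.max() / 2 + 1e-6
```
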
@@ -223,7 +223,7 @@ from torchao.dtypes import to_affine_quantized_intx import copy from torchao.quantization.quant_api import ( quantize_, - int4_weight_only, + Int4WeightOnlyConfig, ) class ToyLinearModel(torch.nn.Module): @@ -249,9 +249,9 @@ m_bf16 = torch.compile(m_bf16, mode='max-autotune') # apply int4 weight only quant (compatible with tinygemm int4 weight only quant mm kernel in torchao) group_size = 32 # only works for torch 2.4+ -quantize_(m, int4_weight_only(group_size=group_size)) +quantize_(m, Int4WeightOnlyConfig(group_size=group_size)) ## If different zero_point_domain needed -# quantize_(m, int4_weight_only(group_size=group_size), zero_point_domain=ZeroPointDomain.FLOAT) +# quantize_(m, Int4WeightOnlyConfig(group_size=group_size, zero_point_domain=ZeroPointDomain.FLOAT)) # temporary workaround for tensor subclass + torch.compile # NOTE: this is only need for torch version < 2.5+ @@ -360,7 +360,7 @@ We're trying to develop kernels for low bit quantization for intx quantization f | | uintx-4-64-hqq | 8.124 | 47.85 | 213.24 | 11.85 | 4.46 | | | uintx-2-8-hqq | 39.605 | 34.83 | 261.42 | 14.99 | 7.51 | -You try can out these apis with the `quantize_` api as above alongside the constructor `uintx_weight_only` an example can be found in in `torchao/_models/llama/generate.py`. +You try can out these apis with the `quantize_` api as above alongside the config `UIntXWeightOnlyConfig`. An example can be found in in `torchao/_models/llama/generate.py`. ### int8_dynamic_activation_intx_weight Quantization We have kernels that do 8-bit dynamic quantization of activations and uintx groupwise quantization of weights. These kernels are experimental and can only be run on a device with an ARM CPU (e.g., a Mac computers with Apple silicon). The benchmarks below were run on an M1 Mac Pro, with 8 perf cores, and 2 efficiency cores, and 32GB of RAM. In all cases, torch.compile was used. diff --git a/torchao/quantization/qat/README.md b/torchao/quantization/qat/README.md index 813b628af7..0f024dbf61 100644 --- a/torchao/quantization/qat/README.md +++ b/torchao/quantization/qat/README.md @@ -71,9 +71,9 @@ def train_loop(m: torch.nn.Module): The recommended way to run QAT in torchao is through the `quantize_` API: 1. **Prepare:** specify how weights and/or activations are to be quantized through -[`FakeQuantizeConfig`](https://github.com/pytorch/ao/blob/v0.7.0/torchao/quantization/qat/api.py#L29) and passing these to [`intx_quantization_aware_training`](https://github.com/pytorch/ao/blob/cedadc741954f47a9e9efac2aa584701f125bc73/torchao/quantization/qat/api.py#L242) +[`FakeQuantizeConfig`](https://github.com/pytorch/ao/blob/v0.7.0/torchao/quantization/qat/api.py#L29) and passing these to [`IntXQuantizationAwareTrainingConfig`](https://github.com/pytorch/ao/blob/cedadc741954f47a9e9efac2aa584701f125bc73/torchao/quantization/qat/api.py#L242) 2. 
-functions such as [`int8_dynamic_activation_int4_weight`](https://github.com/pytorch/ao/blob/v0.7.0/torchao/quantization/quant_api.py#L606)
+functions such as [`Int8DynamicActivationInt4WeightConfig`](https://github.com/pytorch/ao/blob/v0.7.0/torchao/quantization/quant_api.py#L606)
 
 For example:
 
@@ -81,12 +81,12 @@ For example:
 ```python
 from torchao.quantization import (
     quantize_,
-    int8_dynamic_activation_int4_weight,
+    Int8DynamicActivationInt4WeightConfig,
 )
 from torchao.quantization.qat import (
     FakeQuantizeConfig,
-    from_intx_quantization_aware_training,
-    intx_quantization_aware_training,
+    FromIntXQuantizationAwareTrainingConfig,
+    IntXQuantizationAwareTrainingConfig,
 )
 
 model = get_model()
@@ -96,7 +96,7 @@ activation_config = FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=Fal
 weight_config = FakeQuantizeConfig(torch.int4, group_size=32)
 quantize_(
     model,
-    intx_quantization_aware_training(activation_config, weight_config),
+    IntXQuantizationAwareTrainingConfig(activation_config, weight_config),
 )
 
 # train
@@ -105,8 +105,8 @@ train_loop(model)
 # convert: transform fake quantization ops into actual quantized ops
 # swap `FakeQuantizedLinear` back to `torch.nn.Linear` and inserts
 # quantized activation and weight tensor subclasses
-quantize_(model, from_intx_quantization_aware_training())
-quantize_(model, int8_dynamic_activation_int4_weight(group_size=32))
+quantize_(model, FromIntXQuantizationAwareTrainingConfig())
+quantize_(model, Int8DynamicActivationInt4WeightConfig(group_size=32))
 
 # inference or generate
 ```
@@ -117,7 +117,7 @@ the following with a filter function during the prepare step:
 ```
 quantize_(
     m,
-    intx_quantization_aware_training(weight_config=weight_config),
+    IntXQuantizationAwareTrainingConfig(weight_config=weight_config),
     filter_fn=lambda m, _: isinstance(m, torch.nn.Embedding),
 )
 ```

From 4de0f68329d7084dcc6275905e5641dbff50d8c8 Mon Sep 17 00:00:00 2001
From: vasiliy
Date: Thu, 13 Feb 2025 16:15:26 -0800
Subject: [PATCH 22/23] Update

[ghstack-poisoned]
---
 README.md | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/README.md b/README.md
index a35aca2e88..6ba8a85ed8 100644
--- a/README.md
+++ b/README.md
@@ -67,8 +67,8 @@ from torchao.quantization import (
 )
 from torchao.quantization.qat import (
     FakeQuantizeConfig,
-    from_intx_quantization_aware_training,
-    intx_quantization_aware_training,
+    FromIntXQuantizationAwareTrainingConfig,
+    IntXQuantizationAwareTrainingConfig,
 )
 
 # Insert fake quantization
@@ -76,13 +76,13 @@ activation_config = FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=Fal
 weight_config = FakeQuantizeConfig(torch.int4, group_size=32)
 quantize_(
     my_model,
-    intx_quantization_aware_training(activation_config, weight_config),
+    IntXQuantizationAwareTrainingConfig(activation_config, weight_config),
 )
 
 # Run training... (not shown)
 
 # Convert fake quantization to actual quantized operations
-quantize_(my_model, from_intx_quantization_aware_training())
+quantize_(my_model, FromIntXQuantizationAwareTrainingConfig())
 quantize_(my_model, int8_dynamic_activation_int4_weight(group_size=32))
 ```
 
@@ -139,7 +139,7 @@ The best example we have combining the composability of lower bit dtype with com
 
 We've added support for authoring and releasing [custom ops](./torchao/csrc/) that do not graph break with `torch.compile()` so if you love writing kernels but hate packaging them so they work all operating systems and cuda versions, we'd love to accept contributions for your custom ops. We have a few examples you can follow
 
-1. [fp6](torchao/dtypes/floatx) for 2x faster inference over fp16 with an easy to use API `quantize_(model, fpx_weight_only(3, 2))`
+1. [fp6](torchao/dtypes/floatx) for 2x faster inference over fp16 with an easy to use API `quantize_(model, FPXWeightOnlyConfig(3, 2))`
 2. [2:4 Sparse Marlin GEMM](https://github.com/pytorch/ao/pull/733) 2x speedups for FP16xINT4 kernels even at batch sizes up to 256
 3. [int4 tinygemm unpacker](https://github.com/pytorch/ao/pull/415) which makes it easier to switch quantized backends for inference
 

From 060cda8e5083aa7583ef182aee583d667c2abd18 Mon Sep 17 00:00:00 2001
From: vasiliy
Date: Thu, 13 Feb 2025 16:18:02 -0800
Subject: [PATCH 23/23] Update

[ghstack-poisoned]
---
 README.md | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/README.md b/README.md
index 6ba8a85ed8..e3cdc60aba 100644
--- a/README.md
+++ b/README.md
@@ -63,7 +63,7 @@ Post-training quantization can result in a fast and compact model, but may also
 ```python
 from torchao.quantization import (
     quantize_,
-    int8_dynamic_activation_int4_weight,
+    Int8DynamicActivationInt4WeightConfig,
 )
 from torchao.quantization.qat import (
     FakeQuantizeConfig,
@@ -83,7 +83,7 @@ quantize_(
 
 # Convert fake quantization to actual quantized operations
 quantize_(my_model, FromIntXQuantizationAwareTrainingConfig())
-quantize_(my_model, int8_dynamic_activation_int4_weight(group_size=32))
+quantize_(my_model, Int8DynamicActivationInt4WeightConfig(group_size=32))
 ```
 
 ### Float8