From 89aaa600253338dabe52a374c0fe404798d2a244 Mon Sep 17 00:00:00 2001
From: vasiliy
Date: Thu, 13 Feb 2025 16:15:25 -0800
Subject: [PATCH] config migration: fpx, gemlite, uintx

Summary:

Migrates the fpx_weight_only, gemlite_uintx_weight_only and uintx_weight_only
workflows from config-returning functions to AOBaseConfig subclasses
(FPXWeightOnlyConfig, GemliteUIntXWeightOnlyConfig, UIntXWeightOnlyConfig)
registered with register_quantize_module_handler. The old names are kept as
aliases for backwards compatibility, and tests are updated to apply the configs
via quantize_(model, config) instead of calling the config function on a module.

Test Plan:

Updated unit tests in test/dtypes/test_uintx.py, test/hqq/test_hqq_affine.py
and test/quantization/test_quant_api.py.

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: 225102ad641ad0caefd8455e2cdb277fae400ca1
ghstack-comment-id: 2649778077
Pull Request resolved: https://github.com/pytorch/ao/pull/1697
---
 test/dtypes/test_uintx.py           |   6 +-
 test/hqq/test_hqq_affine.py         |   8 +-
 test/quantization/test_quant_api.py |  23 +++-
 torchao/quantization/__init__.py    |   6 +
 torchao/quantization/quant_api.py   | 189 ++++++++++++++++++----------
 5 files changed, 156 insertions(+), 76 deletions(-)

diff --git a/test/dtypes/test_uintx.py b/test/dtypes/test_uintx.py
index da43253678..9bc983885e 100644
--- a/test/dtypes/test_uintx.py
+++ b/test/dtypes/test_uintx.py
@@ -150,7 +150,7 @@ def test_uintx_target_dtype(dtype):
     linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
 
     # make sure it runs
-    uintx_weight_only(dtype)(linear)
+    quantize_(linear, uintx_weight_only(dtype))
     linear(torch.randn(1, 128, dtype=torch.bfloat16, device="cuda"))
 
 
@@ -165,7 +165,7 @@ def test_uintx_target_dtype_compile(dtype):
     linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
 
     # make sure it runs
-    uintx_weight_only(dtype)(linear)
+    quantize_(linear, uintx_weight_only(dtype))
     linear = torch.compile(linear)
     linear(torch.randn(1, 128, dtype=torch.bfloat16, device="cuda"))
 
@@ -196,6 +196,6 @@ def test_uintx_model_size(dtype):
     )
     bf16_size = get_model_size_in_bytes(linear)
     # make sure it runs
-    uintx_weight_only(dtype)(linear[0])
+    quantize_(linear[0], uintx_weight_only(dtype))
     quantized_size = get_model_size_in_bytes(linear)
     assert bf16_size * _dtype_to_ratio[dtype] == quantized_size
diff --git a/test/hqq/test_hqq_affine.py b/test/hqq/test_hqq_affine.py
index 096c9d26ba..d18ff59f99 100644
--- a/test/hqq/test_hqq_affine.py
+++ b/test/hqq/test_hqq_affine.py
@@ -53,12 +53,10 @@ def _eval_hqq(dtype):
     dummy_linear.weight.data = W
     if dtype == torch.uint4:
         config = int4_weight_only(group_size=max(block_size), use_hqq=True)
-        quantize_(dummy_linear, config)
-        q_tensor_hqq = dummy_linear.weight
     else:
-        q_tensor_hqq = uintx_weight_only(
-            dtype, group_size=max(block_size), use_hqq=True
-        )(dummy_linear).weight
+        config = uintx_weight_only(dtype, group_size=max(block_size), use_hqq=True)
+    quantize_(dummy_linear, config)
+    q_tensor_hqq = dummy_linear.weight
 
     quant_linear_layer = torch.nn.Linear(
         W.shape[1], W.shape[0], bias=False, device=W.device
diff --git a/test/quantization/test_quant_api.py b/test/quantization/test_quant_api.py
index 4cb0ee3579..a53f47ac14 100644
--- a/test/quantization/test_quant_api.py
+++ b/test/quantization/test_quant_api.py
@@ -33,11 +33,14 @@
     float8_dynamic_activation_float8_weight,
     float8_static_activation_float8_weight,
     float8_weight_only,
+    fpx_weight_only,
+    gemlite_uintx_weight_only,
     int4_dynamic_activation_int4_weight,
     int4_weight_only,
     int8_dynamic_activation_int4_weight,
     int8_dynamic_activation_int8_weight,
     int8_weight_only,
+    uintx_weight_only,
 )
 from torchao.quantization.quant_primitives import MappingType
 from torchao.quantization.subclass import (
@@ -55,6 +58,13 @@
     unwrap_tensor_subclass,
 )
 
+try:
+    import gemlite  # noqa: F401
+
+    has_gemlite = True
+except ModuleNotFoundError:
+    has_gemlite = False
+
 
 def dynamic_quant(model, example_inputs):
     m = torch.export.export(model, example_inputs, strict=True).module()
@@ -804,6 +814,9 @@ def test_int4wo_cpu(self, dtype, x_dim):
             int8_dynamic_activation_int8_weight(),
             int8_dynamic_activation_int4_weight(),
             int8_weight_only(),
+            fpx_weight_only(ebits=4, mbits=3),
+            gemlite_uintx_weight_only(),
+            uintx_weight_only(dtype=torch.uint4),
         ],
     )
     def test_workflow_e2e_numerics(self, config):
@@ -827,17 +840,23 @@ def test_workflow_e2e_numerics(self, config):
             and is_sm_at_least_90()
         ):
             return unittest.skip("only supported on CUDA capability 8.9, not greater")
+        elif isinstance(config, gemlite_uintx_weight_only) and not has_gemlite:
+            return unittest.skip("gemlite not available")
 
         # scale has to be moved to cuda here because the parametrization init
         # code happens before gating for cuda availability
         if isinstance(config, float8_static_activation_float8_weight):
             config.scale = config.scale.to("cuda")
 
+        dtype = torch.bfloat16
+        if isinstance(config, gemlite_uintx_weight_only):
+            dtype = torch.float16
+
         # set up inputs
-        x = torch.randn(128, 128, device="cuda", dtype=torch.bfloat16)
+        x = torch.randn(128, 128, device="cuda", dtype=dtype)
         # TODO(future): model in float32 leads to error: https://gist.github.com/vkuzo/63b3bcd7818393021a6e3fb4ccf3c469
         # is that expected?
-        m_ref = torch.nn.Sequential(torch.nn.Linear(128, 128)).cuda().bfloat16()
+        m_ref = torch.nn.Sequential(torch.nn.Linear(128, 128)).cuda().to(dtype)
         m_q = copy.deepcopy(m_ref)
 
         # quantize
diff --git a/torchao/quantization/__init__.py b/torchao/quantization/__init__.py
index a1d8bda058..5f15a6bbbe 100644
--- a/torchao/quantization/__init__.py
+++ b/torchao/quantization/__init__.py
@@ -49,11 +49,14 @@
     Float8DynamicActivationFloat8WeightConfig,
     Float8StaticActivationFloat8WeightConfig,
     Float8WeightOnlyConfig,
+    FPXWeightOnlyConfig,
+    GemliteUIntXWeightOnlyConfig,
     Int4DynamicActivationInt4WeightConfig,
     Int4WeightOnlyConfig,
     Int8DynamicActivationInt4WeightConfig,
     Int8DynamicActivationInt8WeightConfig,
     Int8WeightOnlyConfig,
+    UIntXWeightOnlyConfig,
     float8_dynamic_activation_float8_weight,
     float8_static_activation_float8_weight,
     float8_weight_only,
@@ -135,6 +138,9 @@
     "Float8WeightOnlyConfig",
     "Float8DynamicActivationFloat8WeightConfig",
     "Float8StaticActivationFloat8WeightConfig",
+    "UIntXWeightOnlyConfig",
+    "FPXWeightOnlyConfig",
+    "GemliteUIntXWeightOnlyConfig",
     # smooth quant - subject to change
     "get_scale",
     "SmoothFakeDynQuantMixin",
diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py
index 60ee0384c9..e347529929 100644
--- a/torchao/quantization/quant_api.py
+++ b/torchao/quantization/quant_api.py
@@ -729,12 +729,8 @@ def _int4_dynamic_activation_int4_weight_transform(
     return module
 
 
-def gemlite_uintx_weight_only(
-    group_size: Optional[int] = 64,
-    bit_width: int = 4,
-    packing_bitwidth: int = 32,
-    contiguous: Optional[bool] = None,
-):
+@dataclass
+class GemliteUIntXWeightOnlyConfig(AOBaseConfig):
     """
     applies weight only 4 or 8 bit integer quantization and utilizes the gemlite triton kernel and its associated weight packing format.
     This only works for fp16 models. 8 bit quantization is symmetric, 4 bit quantization is asymmetric.
@@ -747,16 +743,39 @@ def gemlite_uintx_weight_only(
         `contiguous`: if set, the weight will be packed as specified. Leaving it as None lets gemlite determine the best choice.
     """
 
+    group_size: Optional[int] = 64
+    bit_width: int = 4
+    packing_bitwidth: int = 32
+    contiguous: Optional[bool] = None
+
+
+# for BC
+gemlite_uintx_weight_only = GemliteUIntXWeightOnlyConfig
+
+
+@register_quantize_module_handler(GemliteUIntXWeightOnlyConfig)
+def _gemlite_uintx_weight_only_transform(
+    module: torch.nn.Module, config: GemliteUIntXWeightOnlyConfig
+):
+    group_size = config.group_size
+    bit_width = config.bit_width
+    packing_bitwidth = config.packing_bitwidth
+    contiguous = config.contiguous
+
+    weight = module.weight
+
     from torchao.dtypes.uintx.gemlite_layout import get_gemlite_aqt_kwargs
 
     use_hqq = True if bit_width == 4 else False
-    apply_fn = lambda weight: to_affine_quantized_intx(
+    new_weight = to_affine_quantized_intx(
         weight,
         **get_gemlite_aqt_kwargs(
             weight, group_size, bit_width, packing_bitwidth, contiguous, use_hqq
         ),
     )
-    return _get_linear_subclass_inserter(apply_fn)
+    module.weight = torch.nn.Parameter(new_weight, requires_grad=False)
+    module.extra_repr = types.MethodType(_linear_extra_repr, module)
+    return module
 
 
 @dataclass
@@ -1380,9 +1399,10 @@ def _float8_static_activation_float8_weight_transform(
     return module
 
 
-def uintx_weight_only(dtype, group_size=64, pack_dim=-1, use_hqq=False):
+@dataclass
+class UIntXWeightOnlyConfig(AOBaseConfig):
     """
-    Applies uintx weight-only asymmetric per-group quantization to linear layers, using uintx quantization where
+    Configuration for applying uintx weight-only asymmetric per-group quantization to linear layers, using uintx quantization where
     x is the number of bits specified by `dtype`
 
     Args:
@@ -1392,6 +1412,28 @@ def uintx_weight_only(dtype, group_size=64, pack_dim=-1, use_hqq=False):
         `pack_dim`: the dimension we use for packing, defaults to -1
         `use_hqq`: whether to use hqq algorithm or the default algorithm to quantize the weight
     """
+
+    dtype: torch.dtype
+    group_size: int = 64
+    pack_dim: int = -1
+    use_hqq: bool = False
+
+
+# for BC
+uintx_weight_only = UIntXWeightOnlyConfig
+
+
+@register_quantize_module_handler(UIntXWeightOnlyConfig)
+def _uintx_weight_only_transform(
+    module: torch.nn.Module, config: UIntXWeightOnlyConfig
+):
+    dtype = config.dtype
+    group_size = config.group_size
+    pack_dim = config.pack_dim
+    use_hqq = config.use_hqq
+
+    weight = module.weight
+
     from torchao.quantization.quant_primitives import _DTYPE_TO_QVALUE_BOUNDS
 
     SUPPORTED_DTYPES = {
@@ -1406,49 +1448,50 @@ def uintx_weight_only(dtype, group_size=64, pack_dim=-1, use_hqq=False):
     }
     assert dtype in SUPPORTED_DTYPES, f"Unsupported dtype for hqq: {dtype}"
 
-    def apply_uintx_weight_only_quant(weight, dtype):
-        mapping_type = MappingType.ASYMMETRIC
-        block_size = (1, group_size)
-
-        if use_hqq:
-            if dtype == torch.uint4:
-                logger.warn(
-                    "Recommended to use `int4_weight_only(group_size, use_hqq=True)` for the best performance"
-                )
-            quant_min, quant_max = _DTYPE_TO_QVALUE_BOUNDS[dtype]
-            dtype = torch.uint8
-            eps = None
-            zero_point_dtype = None
-            zero_point_domain = ZeroPointDomain.FLOAT
-            preserve_zero = False
-            _layout = PlainLayout()
-        else:
-            quant_min, quant_max = None, None
-            eps = torch.finfo(torch.float32).eps
-            zero_point_dtype = torch.int32
-            zero_point_domain = ZeroPointDomain.INT
-            preserve_zero = True
-            _layout = UintxLayout(dtype=dtype, pack_dim=pack_dim)
+    mapping_type = MappingType.ASYMMETRIC
+    block_size = (1, group_size)
 
-        return to_affine_quantized_intx(
-            weight,
-            mapping_type,
-            block_size,
-            dtype,
-            quant_min=quant_min,
-            quant_max=quant_max,
-            eps=eps,
-            zero_point_dtype=zero_point_dtype,
-            zero_point_domain=zero_point_domain,
-            preserve_zero=preserve_zero,
-            _layout=_layout,
-            use_hqq=use_hqq,
-        )
+    if use_hqq:
+        if dtype == torch.uint4:
+            logger.warn(
+                "Recommended to use `int4_weight_only(group_size, use_hqq=True)` for the best performance"
+            )
+        quant_min, quant_max = _DTYPE_TO_QVALUE_BOUNDS[dtype]
+        dtype = torch.uint8
+        eps = None
+        zero_point_dtype = None
+        zero_point_domain = ZeroPointDomain.FLOAT
+        preserve_zero = False
+        _layout = PlainLayout()
+    else:
+        quant_min, quant_max = None, None
+        eps = torch.finfo(torch.float32).eps
+        zero_point_dtype = torch.int32
+        zero_point_domain = ZeroPointDomain.INT
+        preserve_zero = True
+        _layout = UintxLayout(dtype=dtype, pack_dim=pack_dim)
 
-    return _get_linear_subclass_inserter(apply_uintx_weight_only_quant, dtype=dtype)
+    new_weight = to_affine_quantized_intx(
+        weight,
+        mapping_type,
+        block_size,
+        dtype,
+        quant_min=quant_min,
+        quant_max=quant_max,
+        eps=eps,
+        zero_point_dtype=zero_point_dtype,
+        zero_point_domain=zero_point_domain,
+        preserve_zero=preserve_zero,
+        _layout=_layout,
+        use_hqq=use_hqq,
+    )
+    module.weight = torch.nn.Parameter(new_weight, requires_grad=False)
+    module.extra_repr = types.MethodType(_linear_extra_repr, module)
+    return module
 
 
-def fpx_weight_only(ebits: int, mbits: int):
+@dataclass
+class FPXWeightOnlyConfig(AOBaseConfig):
     """Sub-byte floating point dtypes defined by `ebits`: exponent bits and `mbits`: mantissa bits
     e.g. fp6_e3_m2, fp6_e2_m3, ...
     The packing format and kernels are from the fp6-llm paper: https://arxiv.org/abs/2401.14112
@@ -1459,26 +1502,40 @@ def fpx_weight_only(ebits: int, mbits: int):
     in the future
     """
 
-    def apply_quant_llm(weight: torch.Tensor) -> torch.Tensor:
-        from torchao.dtypes import to_affine_quantized_fpx
-        from torchao.dtypes.floatx import FloatxTensorCoreLayout
+    ebits: int
+    mbits: int
 
-        assert (
-            weight.dim() == 2
-        ), f"floatx only works for 2-d Tensor, got: {weight.dim()}"
-        out_dim, in_dim = weight.shape
-        if (in_dim % 64 != 0) or (out_dim % 256 != 0):
-            logger.info(
-                f"Skipping floatx quantization float{ebits + mbits + 1}_{ebits}_{mbits} because "
-                f"the shape is not compatible with the kernel: in_dim={in_dim}, out_dim={out_dim} "
-                "expected in_dim % 64 == 0 and out_dim % 256 == 0"
-            )
-            return weight
-
-        _layout = FloatxTensorCoreLayout(ebits, mbits)
-        return to_affine_quantized_fpx(weight, _layout)
+
+# for BC
+fpx_weight_only = FPXWeightOnlyConfig
+
+
+@register_quantize_module_handler(FPXWeightOnlyConfig)
+def _fpx_weight_only_transform(
+    module: torch.nn.Module, config: FPXWeightOnlyConfig
+) -> torch.nn.Module:
+    ebits = config.ebits
+    mbits = config.mbits
+    weight = module.weight
+
+    from torchao.dtypes import to_affine_quantized_fpx
+    from torchao.dtypes.floatx import FloatxTensorCoreLayout
 
-    return _get_linear_subclass_inserter(apply_quant_llm)
+    assert weight.dim() == 2, f"floatx only works for 2-d Tensor, got: {weight.dim()}"
+    out_dim, in_dim = weight.shape
+    if (in_dim % 64 != 0) or (out_dim % 256 != 0):
+        logger.info(
+            f"Skipping floatx quantization float{ebits + mbits + 1}_{ebits}_{mbits} because "
+            f"the shape is not compatible with the kernel: in_dim={in_dim}, out_dim={out_dim} "
+            "expected in_dim % 64 == 0 and out_dim % 256 == 0"
+        )
+        return module
+
+    _layout = FloatxTensorCoreLayout(ebits, mbits)
+    new_weight = to_affine_quantized_fpx(weight, _layout)
+    module.weight = torch.nn.Parameter(new_weight, requires_grad=False)
+    module.extra_repr = types.MethodType(_linear_extra_repr, module)
+    return module
 
 
 if TORCH_VERSION_AT_LEAST_2_5:
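
A minimal usage sketch of the migrated workflows (not part of the patch itself), mirroring the updated tests above. It assumes a CUDA device, that fpx_weight_only, gemlite_uintx_weight_only, uintx_weight_only and quantize_ are importable from torchao.quantization, and that the optional gemlite package is installed for the gemlite path:

# Sketch based on the updated tests in this patch; assumes CUDA and, for the
# gemlite example, the optional `gemlite` dependency.
import torch

from torchao.quantization import (
    fpx_weight_only,
    gemlite_uintx_weight_only,
    quantize_,
    uintx_weight_only,
)

# uintx: the config is now passed to quantize_ instead of being called on the module,
# i.e. quantize_(linear, uintx_weight_only(dtype)) replaces uintx_weight_only(dtype)(linear)
linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
quantize_(linear, uintx_weight_only(dtype=torch.uint4))
linear(torch.randn(1, 128, dtype=torch.bfloat16, device="cuda"))

# fpx (fp6_e3_m2-style): per the transform above, weights are only converted when
# in_dim % 64 == 0 and out_dim % 256 == 0; otherwise the module is left unchanged
linear = torch.nn.Linear(64, 256, dtype=torch.float16, device="cuda")
quantize_(linear, fpx_weight_only(ebits=3, mbits=2))

# gemlite: fp16 models only, per the GemliteUIntXWeightOnlyConfig docstring
linear = torch.nn.Linear(128, 128, dtype=torch.float16, device="cuda")
quantize_(linear, gemlite_uintx_weight_only(group_size=64, bit_width=4))

Because the old names are kept as aliases for the new config classes, existing call sites that construct the configs keep working; the behavioral change is only that configs are applied through quantize_ rather than by calling the config on a module.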