[wip] configs configs configs!
Summary:

POC for (see the sketch below):

* decoupling configuration from transformation
* no longer passing obscure stateful callables around
* enabling printing of configurations
* reducing the amount of context switching needed to navigate from `quantize_` to
  quantizing a single module

TODO more polish before wider discussion.
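
For context, here is a minimal sketch (not part of this commit) of the pattern this
POC moves toward, following the `AOBaseConfig` docstring added in
`torchao/core/config.py`. `WorkflowFooConfig` and `_workflow_foo_transform` are
hypothetical names used only for illustration.

```
from dataclasses import dataclass

import torch

from torchao.core.config import AOBaseConfig
from torchao.quantization import quantize_, register_quantize_module_handler


# user facing: a plain config object that only holds settings
@dataclass
class WorkflowFooConfig(AOBaseConfig):
    bar: str = "baz"


# non user facing: the transform is registered separately and looked up by
# `quantize_` based on the type of the config it receives
@register_quantize_module_handler(WorkflowFooConfig)
def _workflow_foo_transform(
    module: torch.nn.Module,
    config: WorkflowFooConfig,
) -> torch.nn.Module:
    # a real workflow would swap the weight to a tensor subclass or swap the
    # module here; this sketch is a no-op
    return module


model = torch.nn.Sequential(torch.nn.Linear(128, 128))
quantize_(model, WorkflowFooConfig())  # dispatches to _workflow_foo_transform
print(WorkflowFooConfig())  # configs are plain dataclasses, so they print cleanly
```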

Test Plan:

```
pytest test/quantization/test_quant_api.py -s -x -k test_int4_weight_only_numerics
pytest test/quantization/test_qat.py -s -x -k test_quantize_api_standalone
pytest test/quantization/test_qat.py -s -x -k test_quantize_api_convert_path
```

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: a3f8480b19296a5a414a02939125d1a2f07d3def
ghstack-comment-id: 2607756510
Pull Request resolved: #1595
vkuzo committed Feb 14, 2025
1 parent d3306b2 commit 07fb98c
Showing 11 changed files with 334 additions and 152 deletions.
25 changes: 21 additions & 4 deletions test/dtypes/test_affine_quantized.py
@@ -8,6 +8,7 @@
run_tests,
)

from torchao.core.config import AOBaseConfig
from torchao.dtypes import CutlassInt4PackedLayout, Int4CPULayout, SemiSparseLayout
from torchao.quantization import (
float8_weight_only,
@@ -16,6 +17,7 @@
int8_dynamic_activation_int4_weight,
int8_dynamic_activation_int8_weight,
int8_weight_only,
quantize_,
)
from torchao.quantization.quant_primitives import MappingType, ZeroPointDomain
from torchao.utils import (
@@ -82,7 +84,8 @@ def test_tensor_core_layout_transpose(self):
t = linear.weight
shape = t.shape
apply_int4_weight_only_quant = int4_weight_only(group_size=32)
ql = apply_int4_weight_only_quant(linear)
quantize_(linear, apply_int4_weight_only_quant)
ql = linear
aqt = ql.weight
aqt_shape = aqt.shape
self.assertEqual(aqt_shape, shape)
@@ -102,7 +105,12 @@
)
def test_weights_only(self, apply_quant):
linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
ql = apply_quant(linear)
if isinstance(apply_quant, AOBaseConfig):
quantize_(linear, apply_quant)
ql = linear
else:
# TODO(#1690): delete this once config migration is done
ql = apply_quant(linear)
with tempfile.NamedTemporaryFile() as f:
torch.save(ql.state_dict(), f)
f.seek(0)
@@ -181,7 +189,12 @@ def apply_uint6_weight_only_quant(linear):
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
def test_print_quantized_module(self, apply_quant):
linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
ql = apply_quant(linear)
if isinstance(apply_quant, AOBaseConfig):
quantize_(linear, apply_quant)
ql = linear
else:
# TODO(#1690): delete this once config migration is done
ql = apply_quant(linear)
assert "AffineQuantizedTensor" in str(ql)


@@ -195,7 +208,11 @@ def test_flatten_unflatten(self, device, dtype):
apply_quant_list = get_quantization_functions(False, True, device)
for apply_quant in apply_quant_list:
linear = torch.nn.Linear(128, 256, dtype=dtype, device=device)
ql = apply_quant(linear)
            if isinstance(apply_quant, AOBaseConfig):
                quantize_(linear, apply_quant)
                ql = linear
else:
# TODO(#1690): delete this once config migration is done
ql = apply_quant(linear)
lp_tensor = ql.weight
tensor_data_name_dict, tensor_attributes = lp_tensor.__tensor_flatten__()
tensor_data_dict = {
7 changes: 4 additions & 3 deletions test/hqq/test_hqq_affine.py
@@ -6,6 +6,7 @@
MappingType,
ZeroPointDomain,
int4_weight_only,
quantize_,
uintx_weight_only,
)
from torchao.utils import (
@@ -51,9 +52,9 @@ def _eval_hqq(dtype):
)
dummy_linear.weight.data = W
if dtype == torch.uint4:
q_tensor_hqq = int4_weight_only(group_size=max(block_size), use_hqq=True)(
dummy_linear
).weight
config = int4_weight_only(group_size=max(block_size), use_hqq=True)
quantize_(dummy_linear, config)
q_tensor_hqq = dummy_linear.weight
else:
q_tensor_hqq = uintx_weight_only(
dtype, group_size=max(block_size), use_hqq=True
2 changes: 1 addition & 1 deletion test/quantization/test_qat.py
@@ -1185,7 +1185,7 @@ def test_qat_prototype_bc(self):
@unittest.skipIf(
not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower"
)
def test_quantize_api(self):
def test_quantize_api_standalone(self):
"""
Test that the following:
25 changes: 25 additions & 0 deletions test/quantization/test_quant_api.py
@@ -40,6 +40,7 @@
Int4WeightOnlyQuantizedLinearWeight,
Int8WeightOnlyQuantizedLinearWeight,
)
from torchao.quantization.utils import compute_error
from torchao.utils import (
TORCH_VERSION_AT_LEAST_2_3,
TORCH_VERSION_AT_LEAST_2_4,
@@ -783,6 +784,30 @@ def test_int4wo_cpu(self, dtype, x_dim):
assert "_weight_int4pack_mm_for_cpu" in code[0]
assert "aten.mm.default" not in code[0]

@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+")
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
def test_int4_weight_only_numerics(self):
"""
Simple test of e2e int4_weight_only workflow, comparing numerics
to a bfloat16 baseline.
"""
# set up inputs
x = torch.randn(128, 128, device="cuda", dtype=torch.bfloat16)
# TODO(future): model in float32 leads to error: https://gist.github.com/vkuzo/63b3bcd7818393021a6e3fb4ccf3c469
# is that expected?
m_ref = torch.nn.Sequential(torch.nn.Linear(128, 128)).cuda().bfloat16()
m_int4_wo = copy.deepcopy(m_ref)

# quantize
quantize_(m_int4_wo, int4_weight_only())

with torch.no_grad():
y_ref = m_ref(x)
y_int4_wo = m_int4_wo(x)

sqnr = compute_error(y_ref, y_int4_wo)
assert sqnr >= 20, f"SQNR {sqnr} is too low"


class TestMultiTensorFlow(TestCase):
@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+")
Empty file added torchao/core/__init__.py
Empty file.
29 changes: 29 additions & 0 deletions torchao/core/config.py
@@ -0,0 +1,29 @@
import abc


class AOBaseConfig(abc.ABC):
"""
If a workflow config inherits from this then `quantize_` knows
    how to apply it to a model. For example::
# user facing code
class WorkflowFooConfig(AOBaseConfig): ...
# configuration for workflow `Foo` is defined here
bar = 'baz'
# non user facing code
@register_quantize_module_handler(WorkflowFooConfig)
def _transform(
mod: torch.nn.Module,
config: WorkflowFooConfig,
) -> torch.nn.Module:
        # the transform is implemented here, usually a tensor subclass
# weight swap or a module swap
...
# then, the user calls `quantize_` with a config, and `_transform` is called
    # under the hood by `quantize_`.
"""

pass
6 changes: 6 additions & 0 deletions torchao/quantization/__init__.py
@@ -4,6 +4,7 @@
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.


from torchao.kernel import (
int_scaled_matmul,
safe_int_mm,
Expand Down Expand Up @@ -45,6 +46,7 @@
AffineQuantizedObserverBase,
)
from .quant_api import (
Int4WeightOnlyConfig,
float8_dynamic_activation_float8_weight,
float8_static_activation_float8_weight,
float8_weight_only,
@@ -85,6 +87,7 @@
swap_linear_with_smooth_fq_linear,
)
from .subclass import * # noqa: F403
from .transform_module import register_quantize_module_handler
from .unified import Quantizer, TwoStepQuantizer
from .utils import (
compute_error,
@@ -117,6 +120,7 @@
"fpx_weight_only",
"gemlite_uintx_weight_only",
"swap_conv2d_1x1_to_linear",
"Int4WeightOnlyConfig",
# smooth quant - subject to change
"get_scale",
"SmoothFakeDynQuantMixin",
@@ -144,6 +148,8 @@
# operators/kernels
"safe_int_mm",
"int_scaled_matmul",
# registration of module transforms for quantize_
"register_quantize_module_handler",
# dataclasses and types
"MappingType",
"ZeroPointDomain",
4 changes: 4 additions & 0 deletions torchao/quantization/qat/__init__.py
@@ -1,6 +1,8 @@
from .api import (
ComposableQATQuantizer,
FakeQuantizeConfig,
FromIntXQuantizationAwareTrainingConfig,
IntXQuantizationAwareTrainingConfig,
from_intx_quantization_aware_training,
intx_quantization_aware_training,
)
@@ -20,4 +22,6 @@
"Int8DynActInt4WeightQATQuantizer",
"intx_quantization_aware_training",
"from_intx_quantization_aware_training",
"FromIntXQuantizationAwareTrainingConfig",
"IntXQuantizationAwareTrainingConfig",
]
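
The `qat/__init__.py` change above exports the new QAT config classes. Below is a
usage sketch of the prepare/convert flow they enable, mirroring the docstring
examples in `torchao/quantization/qat/api.py` further down; the specific
`FakeQuantizeConfig` arguments are illustrative assumptions, not taken from this
commit.

```
import torch

from torchao.quantization import quantize_
from torchao.quantization.qat import (
    FakeQuantizeConfig,
    FromIntXQuantizationAwareTrainingConfig,
    IntXQuantizationAwareTrainingConfig,
)

model = torch.nn.Sequential(torch.nn.Linear(512, 512))

# prepare: swap nn.Linear -> FakeQuantizedLinear so training sees fake-quantized
# activations and weights
activation_config = FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
weight_config = FakeQuantizeConfig(torch.int4, group_size=32)
quantize_(
    model,
    IntXQuantizationAwareTrainingConfig(activation_config, weight_config),
)

# ... train with fake quantization ...

# convert: swap FakeQuantizedLinear back to plain nn.Linear
quantize_(model, FromIntXQuantizationAwareTrainingConfig())
```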
118 changes: 70 additions & 48 deletions torchao/quantization/qat/api.py
@@ -5,10 +5,11 @@
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass
from typing import Any, Callable, List, Optional, Union
from typing import Any, List, Optional, Union

import torch

from torchao.core.config import AOBaseConfig
from torchao.quantization.granularity import (
Granularity,
PerAxis,
@@ -22,6 +23,9 @@
TorchAODType,
ZeroPointDomain,
)
from torchao.quantization.transform_module import (
register_quantize_module_handler,
)
from torchao.quantization.unified import TwoStepQuantizer


@@ -241,12 +245,26 @@ def __setattr__(self, name: str, value: Any):
super().__setattr__(name, value)


def intx_quantization_aware_training(
activation_config: Optional[FakeQuantizeConfig] = None,
weight_config: Optional[FakeQuantizeConfig] = None,
) -> Callable:
@dataclass
class IntXQuantizationAwareTrainingConfig(AOBaseConfig):
activation_config: Optional[FakeQuantizeConfig] = None
weight_config: Optional[FakeQuantizeConfig] = None


# for BC
intx_quantization_aware_training = IntXQuantizationAwareTrainingConfig


@register_quantize_module_handler(IntXQuantizationAwareTrainingConfig)
def _intx_quantization_aware_training_transform(
module: torch.nn.Module,
config: IntXQuantizationAwareTrainingConfig,
) -> torch.nn.Module:
"""
Return a function that applies fake quantization to a `torch.nn.Module`.
THIS IS NOT A PUBLIC API - any usage of this outside of torchao
can break at any time.
    Apply fake quantization to a `torch.nn.Module`,
    to be used with :func:`~torchao.quantization.quant_api.quantize_`.
Example usage::
@@ -261,45 +279,40 @@ def intx_quantization_aware_training(
)
quantize_(
model,
intx_quantization_aware_training(activation_config, weight_config),
IntXQuantizationAwareTrainingConfig(activation_config, weight_config),
)
Note: If the returned function is applied on a module that is not
`torch.nn.Linear` or `torch.nn.Embedding`, or it is applied on
`torch.nn.Embedding` with an activation config, then we will raise
ValueError as these are not supported.
"""

def _insert_fake_quantize(mod: torch.nn.Module):
"""
Swap the given module with its corresponding fake quantized version.
"""
from .embedding import FakeQuantizedEmbedding
from .linear import FakeQuantizedLinear

if isinstance(mod, torch.nn.Linear):
return FakeQuantizedLinear.from_linear(
mod,
activation_config,
weight_config,
)
elif isinstance(mod, torch.nn.Embedding):
if activation_config is not None:
raise ValueError(
"Activation fake quantization is not supported for embedding"
)
return FakeQuantizedEmbedding.from_embedding(mod, weight_config)
else:
from .embedding import FakeQuantizedEmbedding
from .linear import FakeQuantizedLinear

mod = module
activation_config = config.activation_config
weight_config = config.weight_config

if isinstance(mod, torch.nn.Linear):
return FakeQuantizedLinear.from_linear(
mod,
activation_config,
weight_config,
)
elif isinstance(mod, torch.nn.Embedding):
if activation_config is not None:
raise ValueError(
"Module of type '%s' does not have QAT support" % type(mod)
"Activation fake quantization is not supported for embedding"
)
return FakeQuantizedEmbedding.from_embedding(mod, weight_config)
else:
raise ValueError("Module of type '%s' does not have QAT support" % type(mod))

return _insert_fake_quantize


def from_intx_quantization_aware_training() -> Callable:
class FromIntXQuantizationAwareTrainingConfig(AOBaseConfig):
"""
Return a function that converts a model with fake quantized modules,
Object that knows how to convert a model with fake quantized modules,
such as :func:`~torchao.quantization.qat.linear.FakeQuantizedLinear`
and :func:`~torchao.quantization.qat.linear.FakeQuantizedEmbedding`,
back to model with the original, corresponding modules without
@@ -311,26 +324,35 @@ def from_intx_quantization_aware_training() -> Callable:
from torchao.quantization import quantize_
quantize_(
model_with_fake_quantized_linears,
from_intx_quantization_aware_training(),
FromIntXQuantizationAwareTrainingConfig(),
)
"""

def _remove_fake_quantize(mod: torch.nn.Module):
"""
If the given module is a fake quantized module, return the original
corresponding version of the module without fake quantization.
"""
from .embedding import FakeQuantizedEmbedding
from .linear import FakeQuantizedLinear
pass


# for BC
from_intx_quantization_aware_training = FromIntXQuantizationAwareTrainingConfig

if isinstance(mod, FakeQuantizedLinear):
return mod.to_linear()
elif isinstance(mod, FakeQuantizedEmbedding):
return mod.to_embedding()
else:
return mod

return _remove_fake_quantize
@register_quantize_module_handler(FromIntXQuantizationAwareTrainingConfig)
def _from_intx_quantization_aware_training_transform(
mod: torch.nn.Module,
config: FromIntXQuantizationAwareTrainingConfig,
) -> torch.nn.Module:
"""
If the given module is a fake quantized module, return the original
corresponding version of the module without fake quantization.
"""
from .embedding import FakeQuantizedEmbedding
from .linear import FakeQuantizedLinear

if isinstance(mod, FakeQuantizedLinear):
return mod.to_linear()
elif isinstance(mod, FakeQuantizedEmbedding):
return mod.to_embedding()
else:
return mod


class ComposableQATQuantizer(TwoStepQuantizer):
